2
0

Merge pull request #515 from fcelda/replication-allow-query

Allow normal queries on replication connections
This commit is contained in:
Jack Christensen
2019-04-02 18:46:41 -05:00
committed by GitHub
2 changed files with 72 additions and 47 deletions
+47 -47
View File
@@ -163,23 +163,35 @@ func ReplicationConnect(config ConnConfig) (r *ReplicationConn, err error) {
config.RuntimeParams = make(map[string]string) config.RuntimeParams = make(map[string]string)
} }
config.RuntimeParams["replication"] = "database" config.RuntimeParams["replication"] = "database"
config.PreferSimpleProtocol = true
c, err := Connect(config) c, err := Connect(config)
if err != nil { if err != nil {
return return
} }
return &ReplicationConn{c: c}, nil return &ReplicationConn{c}, nil
} }
// ReplicationConn is a PostgreSQL connection handle established in the
// replication mode which enables a special set of commands for streaming WAL
// changes from the server.
//
// When in replication mode, only the simple query protocol can be used
// (see PreferSimpleProtocol in ConnConfig). Execution of normal SQL queries on
// the connection is possible but may be limited in available functionality.
// Most notably, prepared statements won't work.
//
// See https://www.postgresql.org/docs/11/protocol-replication.html for
// details.
type ReplicationConn struct { type ReplicationConn struct {
c *Conn *Conn
} }
// Send standby status to the server, which both acts as a keepalive // Send standby status to the server, which both acts as a keepalive
// message to the server, as well as carries the WAL position of the // message to the server, as well as carries the WAL position of the
// client, which then updates the server's replication slot position. // client, which then updates the server's replication slot position.
func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) { func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) {
buf := rc.c.wbuf buf := rc.wbuf
buf = append(buf, copyData) buf = append(buf, copyData)
sp := len(buf) sp := len(buf)
buf = pgio.AppendInt32(buf, -1) buf = pgio.AppendInt32(buf, -1)
@@ -193,46 +205,34 @@ func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) {
pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
_, err = rc.c.conn.Write(buf) _, err = rc.conn.Write(buf)
if err != nil { if err != nil {
rc.c.die(err) rc.die(err)
} }
return return
} }
func (rc *ReplicationConn) Close() error {
return rc.c.Close()
}
func (rc *ReplicationConn) IsAlive() bool {
return rc.c.IsAlive()
}
func (rc *ReplicationConn) CauseOfDeath() error {
return rc.c.CauseOfDeath()
}
func (rc *ReplicationConn) GetConnInfo() *pgtype.ConnInfo { func (rc *ReplicationConn) GetConnInfo() *pgtype.ConnInfo {
return rc.c.ConnInfo return rc.ConnInfo
} }
func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err error) { func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err error) {
msg, err := rc.c.rxMsg() msg, err := rc.rxMsg()
if err != nil { if err != nil {
return return
} }
switch msg := msg.(type) { switch msg := msg.(type) {
case *pgproto3.NoticeResponse: case *pgproto3.NoticeResponse:
pgError := rc.c.rxErrorResponse((*pgproto3.ErrorResponse)(msg)) pgError := rc.rxErrorResponse((*pgproto3.ErrorResponse)(msg))
if rc.c.shouldLog(LogLevelInfo) { if rc.shouldLog(LogLevelInfo) {
rc.c.log(LogLevelInfo, pgError.Error(), nil) rc.log(LogLevelInfo, pgError.Error(), nil)
} }
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
err = rc.c.rxErrorResponse(msg) err = rc.rxErrorResponse(msg)
if rc.c.shouldLog(LogLevelError) { if rc.shouldLog(LogLevelError) {
rc.c.log(LogLevelError, err.Error(), nil) rc.log(LogLevelError, err.Error(), nil)
} }
return return
case *pgproto3.CopyBothResponse: case *pgproto3.CopyBothResponse:
@@ -269,13 +269,13 @@ func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err
h := &ServerHeartbeat{ServerWalEnd: serverWalEnd, ServerTime: serverTime, ReplyRequested: replyNow} h := &ServerHeartbeat{ServerWalEnd: serverWalEnd, ServerTime: serverTime, ReplyRequested: replyNow}
return &ReplicationMessage{ServerHeartbeat: h}, nil return &ReplicationMessage{ServerHeartbeat: h}, nil
default: default:
if rc.c.shouldLog(LogLevelError) { if rc.shouldLog(LogLevelError) {
rc.c.log(LogLevelError, "Unexpected data playload message type", map[string]interface{}{"type": msgType}) rc.log(LogLevelError, "Unexpected data playload message type", map[string]interface{}{"type": msgType})
} }
} }
default: default:
if rc.c.shouldLog(LogLevelError) { if rc.shouldLog(LogLevelError) {
rc.c.log(LogLevelError, "Unexpected replication message type", map[string]interface{}{"type": msg}) rc.log(LogLevelError, "Unexpected replication message type", map[string]interface{}{"type": msg})
} }
} }
return return
@@ -300,12 +300,12 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl
go func() { go func() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
if err := rc.c.conn.SetDeadline(time.Now()); err != nil { if err := rc.conn.SetDeadline(time.Now()); err != nil {
rc.Close() // Close connection if unable to set deadline rc.Close() // Close connection if unable to set deadline
return return
} }
rc.c.closedChan <- ctx.Err() rc.closedChan <- ctx.Err()
case <-rc.c.doneChan: case <-rc.doneChan:
} }
}() }()
@@ -313,8 +313,8 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl
var err error var err error
select { select {
case err = <-rc.c.closedChan: case err = <-rc.closedChan:
if err := rc.c.conn.SetDeadline(time.Time{}); err != nil { if err := rc.conn.SetDeadline(time.Time{}); err != nil {
rc.Close() // Close connection if unable to disable deadline rc.Close() // Close connection if unable to disable deadline
return nil, err return nil, err
} }
@@ -322,7 +322,7 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl
if opErr == nil { if opErr == nil {
err = nil err = nil
} }
case rc.c.doneChan <- struct{}{}: case rc.doneChan <- struct{}{}:
err = opErr err = opErr
} }
@@ -330,34 +330,34 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl
} }
func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) {
rc.c.lastActivityTime = time.Now() rc.lastActivityTime = time.Now()
rows := rc.c.getRows(sql, nil) rows := rc.getRows(sql, nil)
if err := rc.c.lock(); err != nil { if err := rc.lock(); err != nil {
rows.fatal(err) rows.fatal(err)
return rows, err return rows, err
} }
rows.unlockConn = true rows.unlockConn = true
err := rc.c.sendSimpleQuery(sql) err := rc.sendSimpleQuery(sql)
if err != nil { if err != nil {
rows.fatal(err) rows.fatal(err)
} }
msg, err := rc.c.rxMsg() msg, err := rc.rxMsg()
if err != nil { if err != nil {
return nil, err return nil, err
} }
switch msg := msg.(type) { switch msg := msg.(type) {
case *pgproto3.RowDescription: case *pgproto3.RowDescription:
rows.fields = rc.c.rxRowDescription(msg) rows.fields = rc.rxRowDescription(msg)
// We don't have c.PgTypes here because we're a replication // We don't have c.PgTypes here because we're a replication
// connection. This means the field descriptions will have // connection. This means the field descriptions will have
// only OIDs. Not much we can do about this. // only OIDs. Not much we can do about this.
default: default:
if e := rc.c.processContextFreeMsg(msg); e != nil { if e := rc.processContextFreeMsg(msg); e != nil {
rows.fatal(e) rows.fatal(e)
return rows, e return rows, e
} }
@@ -417,7 +417,7 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti
queryString += fmt.Sprintf(" ( %s )", strings.Join(pluginArguments, ", ")) queryString += fmt.Sprintf(" ( %s )", strings.Join(pluginArguments, ", "))
} }
if err = rc.c.sendQuery(queryString); err != nil { if err = rc.sendQuery(queryString); err != nil {
return return
} }
@@ -431,8 +431,8 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti
var r *ReplicationMessage var r *ReplicationMessage
r, err = rc.WaitForReplicationMessage(ctx) r, err = rc.WaitForReplicationMessage(ctx)
if err != nil && r != nil { if err != nil && r != nil {
if rc.c.shouldLog(LogLevelError) { if rc.shouldLog(LogLevelError) {
rc.c.log(LogLevelError, "Unexpected replication message", map[string]interface{}{"msg": r, "err": err}) rc.log(LogLevelError, "Unexpected replication message", map[string]interface{}{"msg": r, "err": err})
} }
} }
@@ -441,7 +441,7 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti
// Create the replication slot, using the given name and output plugin. // Create the replication slot, using the given name and output plugin.
func (rc *ReplicationConn) CreateReplicationSlot(slotName, outputPlugin string) (err error) { func (rc *ReplicationConn) CreateReplicationSlot(slotName, outputPlugin string) (err error) {
_, err = rc.c.Exec(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s NOEXPORT_SNAPSHOT", slotName, outputPlugin)) _, err = rc.Exec(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s NOEXPORT_SNAPSHOT", slotName, outputPlugin))
return return
} }
@@ -459,6 +459,6 @@ func (rc *ReplicationConn) CreateReplicationSlotEx(slotName, outputPlugin string
// Drop the replication slot for the given name // Drop the replication slot for the given name
func (rc *ReplicationConn) DropReplicationSlot(slotName string) (err error) { func (rc *ReplicationConn) DropReplicationSlot(slotName string) (err error) {
_, err = rc.c.Exec(fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName)) _, err = rc.Exec(fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName))
return return
} }
+25
View File
@@ -343,3 +343,28 @@ func TestStandbyStatusParsing(t *testing.T) {
t.Errorf("Unexpected write position %d", status.WalWritePosition) t.Errorf("Unexpected write position %d", status.WalWritePosition)
} }
} }
func TestSimpleProtocolEnforcement(t *testing.T) {
if replicationConnConfig == nil {
t.Skip("Skipping due to undefined replicationConnConfig")
}
replicationConn := mustReplicationConnect(t, *replicationConnConfig)
defer closeReplicationConn(t, replicationConn)
query := "select count(*) from pg_replication_slots"
// Check that the simple query protocol is used by default
rows, err := replicationConn.Query(query)
if err != nil {
t.Fatalf("Query failed: %v", err)
}
rows.Close()
// Check that using the extended query protocol will fail
rows, err = replicationConn.QueryEx(context.Background(), query, &pgx.QueryExOptions{SimpleProtocol: false})
if err == nil {
t.Fatal("Query expected to fail.")
}
rows.Close()
}