2
0

Merge pull request #515 from fcelda/replication-allow-query

Allow normal queries on replication connections
This commit is contained in:
Jack Christensen
2019-04-02 18:46:41 -05:00
committed by GitHub
2 changed files with 72 additions and 47 deletions
+47 -47
View File
@@ -163,23 +163,35 @@ func ReplicationConnect(config ConnConfig) (r *ReplicationConn, err error) {
config.RuntimeParams = make(map[string]string)
}
config.RuntimeParams["replication"] = "database"
config.PreferSimpleProtocol = true
c, err := Connect(config)
if err != nil {
return
}
return &ReplicationConn{c: c}, nil
return &ReplicationConn{c}, nil
}
// ReplicationConn is a PostgreSQL connection handle established in the
// replication mode which enables a special set of commands for streaming WAL
// changes from the server.
//
// When in replication mode, only the simple query protocol can be used
// (see PreferSimpleProtocol in ConnConfig). Execution of normal SQL queries on
// the connection is possible but may be limited in available functionality.
// Most notably, prepared statements won't work.
//
// See https://www.postgresql.org/docs/11/protocol-replication.html for
// details.
type ReplicationConn struct {
c *Conn
*Conn
}
// Send standby status to the server, which both acts as a keepalive
// message to the server, as well as carries the WAL position of the
// client, which then updates the server's replication slot position.
func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) {
buf := rc.c.wbuf
buf := rc.wbuf
buf = append(buf, copyData)
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
@@ -193,46 +205,34 @@ func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) {
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
_, err = rc.c.conn.Write(buf)
_, err = rc.conn.Write(buf)
if err != nil {
rc.c.die(err)
rc.die(err)
}
return
}
func (rc *ReplicationConn) Close() error {
return rc.c.Close()
}
func (rc *ReplicationConn) IsAlive() bool {
return rc.c.IsAlive()
}
func (rc *ReplicationConn) CauseOfDeath() error {
return rc.c.CauseOfDeath()
}
func (rc *ReplicationConn) GetConnInfo() *pgtype.ConnInfo {
return rc.c.ConnInfo
return rc.ConnInfo
}
func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err error) {
msg, err := rc.c.rxMsg()
msg, err := rc.rxMsg()
if err != nil {
return
}
switch msg := msg.(type) {
case *pgproto3.NoticeResponse:
pgError := rc.c.rxErrorResponse((*pgproto3.ErrorResponse)(msg))
if rc.c.shouldLog(LogLevelInfo) {
rc.c.log(LogLevelInfo, pgError.Error(), nil)
pgError := rc.rxErrorResponse((*pgproto3.ErrorResponse)(msg))
if rc.shouldLog(LogLevelInfo) {
rc.log(LogLevelInfo, pgError.Error(), nil)
}
case *pgproto3.ErrorResponse:
err = rc.c.rxErrorResponse(msg)
if rc.c.shouldLog(LogLevelError) {
rc.c.log(LogLevelError, err.Error(), nil)
err = rc.rxErrorResponse(msg)
if rc.shouldLog(LogLevelError) {
rc.log(LogLevelError, err.Error(), nil)
}
return
case *pgproto3.CopyBothResponse:
@@ -269,13 +269,13 @@ func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err
h := &ServerHeartbeat{ServerWalEnd: serverWalEnd, ServerTime: serverTime, ReplyRequested: replyNow}
return &ReplicationMessage{ServerHeartbeat: h}, nil
default:
if rc.c.shouldLog(LogLevelError) {
rc.c.log(LogLevelError, "Unexpected data playload message type", map[string]interface{}{"type": msgType})
if rc.shouldLog(LogLevelError) {
rc.log(LogLevelError, "Unexpected data playload message type", map[string]interface{}{"type": msgType})
}
}
default:
if rc.c.shouldLog(LogLevelError) {
rc.c.log(LogLevelError, "Unexpected replication message type", map[string]interface{}{"type": msg})
if rc.shouldLog(LogLevelError) {
rc.log(LogLevelError, "Unexpected replication message type", map[string]interface{}{"type": msg})
}
}
return
@@ -300,12 +300,12 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl
go func() {
select {
case <-ctx.Done():
if err := rc.c.conn.SetDeadline(time.Now()); err != nil {
if err := rc.conn.SetDeadline(time.Now()); err != nil {
rc.Close() // Close connection if unable to set deadline
return
}
rc.c.closedChan <- ctx.Err()
case <-rc.c.doneChan:
rc.closedChan <- ctx.Err()
case <-rc.doneChan:
}
}()
@@ -313,8 +313,8 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl
var err error
select {
case err = <-rc.c.closedChan:
if err := rc.c.conn.SetDeadline(time.Time{}); err != nil {
case err = <-rc.closedChan:
if err := rc.conn.SetDeadline(time.Time{}); err != nil {
rc.Close() // Close connection if unable to disable deadline
return nil, err
}
@@ -322,7 +322,7 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl
if opErr == nil {
err = nil
}
case rc.c.doneChan <- struct{}{}:
case rc.doneChan <- struct{}{}:
err = opErr
}
@@ -330,34 +330,34 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl
}
func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) {
rc.c.lastActivityTime = time.Now()
rc.lastActivityTime = time.Now()
rows := rc.c.getRows(sql, nil)
rows := rc.getRows(sql, nil)
if err := rc.c.lock(); err != nil {
if err := rc.lock(); err != nil {
rows.fatal(err)
return rows, err
}
rows.unlockConn = true
err := rc.c.sendSimpleQuery(sql)
err := rc.sendSimpleQuery(sql)
if err != nil {
rows.fatal(err)
}
msg, err := rc.c.rxMsg()
msg, err := rc.rxMsg()
if err != nil {
return nil, err
}
switch msg := msg.(type) {
case *pgproto3.RowDescription:
rows.fields = rc.c.rxRowDescription(msg)
rows.fields = rc.rxRowDescription(msg)
// We don't have c.PgTypes here because we're a replication
// connection. This means the field descriptions will have
// only OIDs. Not much we can do about this.
default:
if e := rc.c.processContextFreeMsg(msg); e != nil {
if e := rc.processContextFreeMsg(msg); e != nil {
rows.fatal(e)
return rows, e
}
@@ -417,7 +417,7 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti
queryString += fmt.Sprintf(" ( %s )", strings.Join(pluginArguments, ", "))
}
if err = rc.c.sendQuery(queryString); err != nil {
if err = rc.sendQuery(queryString); err != nil {
return
}
@@ -431,8 +431,8 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti
var r *ReplicationMessage
r, err = rc.WaitForReplicationMessage(ctx)
if err != nil && r != nil {
if rc.c.shouldLog(LogLevelError) {
rc.c.log(LogLevelError, "Unexpected replication message", map[string]interface{}{"msg": r, "err": err})
if rc.shouldLog(LogLevelError) {
rc.log(LogLevelError, "Unexpected replication message", map[string]interface{}{"msg": r, "err": err})
}
}
@@ -441,7 +441,7 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti
// Create the replication slot, using the given name and output plugin.
func (rc *ReplicationConn) CreateReplicationSlot(slotName, outputPlugin string) (err error) {
_, err = rc.c.Exec(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s NOEXPORT_SNAPSHOT", slotName, outputPlugin))
_, err = rc.Exec(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s NOEXPORT_SNAPSHOT", slotName, outputPlugin))
return
}
@@ -459,6 +459,6 @@ func (rc *ReplicationConn) CreateReplicationSlotEx(slotName, outputPlugin string
// Drop the replication slot for the given name
func (rc *ReplicationConn) DropReplicationSlot(slotName string) (err error) {
_, err = rc.c.Exec(fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName))
_, err = rc.Exec(fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName))
return
}
+25
View File
@@ -343,3 +343,28 @@ func TestStandbyStatusParsing(t *testing.T) {
t.Errorf("Unexpected write position %d", status.WalWritePosition)
}
}
func TestSimpleProtocolEnforcement(t *testing.T) {
if replicationConnConfig == nil {
t.Skip("Skipping due to undefined replicationConnConfig")
}
replicationConn := mustReplicationConnect(t, *replicationConnConfig)
defer closeReplicationConn(t, replicationConn)
query := "select count(*) from pg_replication_slots"
// Check that the simple query protocol is used by default
rows, err := replicationConn.Query(query)
if err != nil {
t.Fatalf("Query failed: %v", err)
}
rows.Close()
// Check that using the extended query protocol will fail
rows, err = replicationConn.QueryEx(context.Background(), query, &pgx.QueryExOptions{SimpleProtocol: false})
if err == nil {
t.Fatal("Query expected to fail.")
}
rows.Close()
}