2
0

Allow normal queries on replication connections

The replication connection allows executing most of the SQL queries
which are available on non-replication connections.
This commit is contained in:
Jan Vcelak
2019-03-07 12:32:26 +01:00
parent 6067cfab4f
commit 1edfd3b682
+38 -38
View File
@@ -168,18 +168,18 @@ func ReplicationConnect(config ConnConfig) (r *ReplicationConn, err error) {
if err != nil { if err != nil {
return return
} }
return &ReplicationConn{c: c}, nil return &ReplicationConn{c}, nil
} }
type ReplicationConn struct { type ReplicationConn struct {
c *Conn *Conn
} }
// Send standby status to the server, which both acts as a keepalive // Send standby status to the server, which both acts as a keepalive
// message to the server, as well as carries the WAL position of the // message to the server, as well as carries the WAL position of the
// client, which then updates the server's replication slot position. // client, which then updates the server's replication slot position.
func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) { func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) {
buf := rc.c.wbuf buf := rc.wbuf
buf = append(buf, copyData) buf = append(buf, copyData)
sp := len(buf) sp := len(buf)
buf = pgio.AppendInt32(buf, -1) buf = pgio.AppendInt32(buf, -1)
@@ -193,46 +193,46 @@ func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) {
pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
_, err = rc.c.conn.Write(buf) _, err = rc.conn.Write(buf)
if err != nil { if err != nil {
rc.c.die(err) rc.die(err)
} }
return return
} }
func (rc *ReplicationConn) Close() error { func (rc *ReplicationConn) Close() error {
return rc.c.Close() return rc.Close()
} }
func (rc *ReplicationConn) IsAlive() bool { func (rc *ReplicationConn) IsAlive() bool {
return rc.c.IsAlive() return rc.IsAlive()
} }
func (rc *ReplicationConn) CauseOfDeath() error { func (rc *ReplicationConn) CauseOfDeath() error {
return rc.c.CauseOfDeath() return rc.CauseOfDeath()
} }
func (rc *ReplicationConn) GetConnInfo() *pgtype.ConnInfo { func (rc *ReplicationConn) GetConnInfo() *pgtype.ConnInfo {
return rc.c.ConnInfo return rc.ConnInfo
} }
func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err error) { func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err error) {
msg, err := rc.c.rxMsg() msg, err := rc.rxMsg()
if err != nil { if err != nil {
return return
} }
switch msg := msg.(type) { switch msg := msg.(type) {
case *pgproto3.NoticeResponse: case *pgproto3.NoticeResponse:
pgError := rc.c.rxErrorResponse((*pgproto3.ErrorResponse)(msg)) pgError := rc.rxErrorResponse((*pgproto3.ErrorResponse)(msg))
if rc.c.shouldLog(LogLevelInfo) { if rc.shouldLog(LogLevelInfo) {
rc.c.log(LogLevelInfo, pgError.Error(), nil) rc.log(LogLevelInfo, pgError.Error(), nil)
} }
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
err = rc.c.rxErrorResponse(msg) err = rc.rxErrorResponse(msg)
if rc.c.shouldLog(LogLevelError) { if rc.shouldLog(LogLevelError) {
rc.c.log(LogLevelError, err.Error(), nil) rc.log(LogLevelError, err.Error(), nil)
} }
return return
case *pgproto3.CopyBothResponse: case *pgproto3.CopyBothResponse:
@@ -269,13 +269,13 @@ func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err
h := &ServerHeartbeat{ServerWalEnd: serverWalEnd, ServerTime: serverTime, ReplyRequested: replyNow} h := &ServerHeartbeat{ServerWalEnd: serverWalEnd, ServerTime: serverTime, ReplyRequested: replyNow}
return &ReplicationMessage{ServerHeartbeat: h}, nil return &ReplicationMessage{ServerHeartbeat: h}, nil
default: default:
if rc.c.shouldLog(LogLevelError) { if rc.shouldLog(LogLevelError) {
rc.c.log(LogLevelError, "Unexpected data playload message type", map[string]interface{}{"type": msgType}) rc.log(LogLevelError, "Unexpected data playload message type", map[string]interface{}{"type": msgType})
} }
} }
default: default:
if rc.c.shouldLog(LogLevelError) { if rc.shouldLog(LogLevelError) {
rc.c.log(LogLevelError, "Unexpected replication message type", map[string]interface{}{"type": msg}) rc.log(LogLevelError, "Unexpected replication message type", map[string]interface{}{"type": msg})
} }
} }
return return
@@ -300,12 +300,12 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl
go func() { go func() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
if err := rc.c.conn.SetDeadline(time.Now()); err != nil { if err := rc.conn.SetDeadline(time.Now()); err != nil {
rc.Close() // Close connection if unable to set deadline rc.Close() // Close connection if unable to set deadline
return return
} }
rc.c.closedChan <- ctx.Err() rc.closedChan <- ctx.Err()
case <-rc.c.doneChan: case <-rc.doneChan:
} }
}() }()
@@ -313,8 +313,8 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl
var err error var err error
select { select {
case err = <-rc.c.closedChan: case err = <-rc.closedChan:
if err := rc.c.conn.SetDeadline(time.Time{}); err != nil { if err := rc.conn.SetDeadline(time.Time{}); err != nil {
rc.Close() // Close connection if unable to disable deadline rc.Close() // Close connection if unable to disable deadline
return nil, err return nil, err
} }
@@ -322,7 +322,7 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl
if opErr == nil { if opErr == nil {
err = nil err = nil
} }
case rc.c.doneChan <- struct{}{}: case rc.doneChan <- struct{}{}:
err = opErr err = opErr
} }
@@ -330,34 +330,34 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl
} }
func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) {
rc.c.lastActivityTime = time.Now() rc.lastActivityTime = time.Now()
rows := rc.c.getRows(sql, nil) rows := rc.getRows(sql, nil)
if err := rc.c.lock(); err != nil { if err := rc.lock(); err != nil {
rows.fatal(err) rows.fatal(err)
return rows, err return rows, err
} }
rows.unlockConn = true rows.unlockConn = true
err := rc.c.sendSimpleQuery(sql) err := rc.sendSimpleQuery(sql)
if err != nil { if err != nil {
rows.fatal(err) rows.fatal(err)
} }
msg, err := rc.c.rxMsg() msg, err := rc.rxMsg()
if err != nil { if err != nil {
return nil, err return nil, err
} }
switch msg := msg.(type) { switch msg := msg.(type) {
case *pgproto3.RowDescription: case *pgproto3.RowDescription:
rows.fields = rc.c.rxRowDescription(msg) rows.fields = rc.rxRowDescription(msg)
// We don't have c.PgTypes here because we're a replication // We don't have c.PgTypes here because we're a replication
// connection. This means the field descriptions will have // connection. This means the field descriptions will have
// only OIDs. Not much we can do about this. // only OIDs. Not much we can do about this.
default: default:
if e := rc.c.processContextFreeMsg(msg); e != nil { if e := rc.processContextFreeMsg(msg); e != nil {
rows.fatal(e) rows.fatal(e)
return rows, e return rows, e
} }
@@ -417,7 +417,7 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti
queryString += fmt.Sprintf(" ( %s )", strings.Join(pluginArguments, ", ")) queryString += fmt.Sprintf(" ( %s )", strings.Join(pluginArguments, ", "))
} }
if err = rc.c.sendQuery(queryString); err != nil { if err = rc.sendQuery(queryString); err != nil {
return return
} }
@@ -431,8 +431,8 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti
var r *ReplicationMessage var r *ReplicationMessage
r, err = rc.WaitForReplicationMessage(ctx) r, err = rc.WaitForReplicationMessage(ctx)
if err != nil && r != nil { if err != nil && r != nil {
if rc.c.shouldLog(LogLevelError) { if rc.shouldLog(LogLevelError) {
rc.c.log(LogLevelError, "Unexpected replication message", map[string]interface{}{"msg": r, "err": err}) rc.log(LogLevelError, "Unexpected replication message", map[string]interface{}{"msg": r, "err": err})
} }
} }
@@ -441,7 +441,7 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti
// Create the replication slot, using the given name and output plugin. // Create the replication slot, using the given name and output plugin.
func (rc *ReplicationConn) CreateReplicationSlot(slotName, outputPlugin string) (err error) { func (rc *ReplicationConn) CreateReplicationSlot(slotName, outputPlugin string) (err error) {
_, err = rc.c.Exec(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s NOEXPORT_SNAPSHOT", slotName, outputPlugin)) _, err = rc.Exec(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s NOEXPORT_SNAPSHOT", slotName, outputPlugin))
return return
} }
@@ -459,6 +459,6 @@ func (rc *ReplicationConn) CreateReplicationSlotEx(slotName, outputPlugin string
// Drop the replication slot for the given name // Drop the replication slot for the given name
func (rc *ReplicationConn) DropReplicationSlot(slotName string) (err error) { func (rc *ReplicationConn) DropReplicationSlot(slotName string) (err error) {
_, err = rc.c.Exec(fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName)) _, err = rc.Exec(fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName))
return return
} }