Allow normal queries on replication connections
The replication connection allows executing most of the SQL queries which are available on non-replication connections.
This commit is contained in:
+38
-38
@@ -168,18 +168,18 @@ func ReplicationConnect(config ConnConfig) (r *ReplicationConn, err error) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return &ReplicationConn{c: c}, nil
|
||||
return &ReplicationConn{c}, nil
|
||||
}
|
||||
|
||||
type ReplicationConn struct {
|
||||
c *Conn
|
||||
*Conn
|
||||
}
|
||||
|
||||
// Send standby status to the server, which both acts as a keepalive
|
||||
// message to the server, as well as carries the WAL position of the
|
||||
// client, which then updates the server's replication slot position.
|
||||
func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) {
|
||||
buf := rc.c.wbuf
|
||||
buf := rc.wbuf
|
||||
buf = append(buf, copyData)
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
@@ -193,46 +193,46 @@ func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) {
|
||||
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
_, err = rc.c.conn.Write(buf)
|
||||
_, err = rc.conn.Write(buf)
|
||||
if err != nil {
|
||||
rc.c.die(err)
|
||||
rc.die(err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (rc *ReplicationConn) Close() error {
|
||||
return rc.c.Close()
|
||||
return rc.Close()
|
||||
}
|
||||
|
||||
func (rc *ReplicationConn) IsAlive() bool {
|
||||
return rc.c.IsAlive()
|
||||
return rc.IsAlive()
|
||||
}
|
||||
|
||||
func (rc *ReplicationConn) CauseOfDeath() error {
|
||||
return rc.c.CauseOfDeath()
|
||||
return rc.CauseOfDeath()
|
||||
}
|
||||
|
||||
func (rc *ReplicationConn) GetConnInfo() *pgtype.ConnInfo {
|
||||
return rc.c.ConnInfo
|
||||
return rc.ConnInfo
|
||||
}
|
||||
|
||||
func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err error) {
|
||||
msg, err := rc.c.rxMsg()
|
||||
msg, err := rc.rxMsg()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.NoticeResponse:
|
||||
pgError := rc.c.rxErrorResponse((*pgproto3.ErrorResponse)(msg))
|
||||
if rc.c.shouldLog(LogLevelInfo) {
|
||||
rc.c.log(LogLevelInfo, pgError.Error(), nil)
|
||||
pgError := rc.rxErrorResponse((*pgproto3.ErrorResponse)(msg))
|
||||
if rc.shouldLog(LogLevelInfo) {
|
||||
rc.log(LogLevelInfo, pgError.Error(), nil)
|
||||
}
|
||||
case *pgproto3.ErrorResponse:
|
||||
err = rc.c.rxErrorResponse(msg)
|
||||
if rc.c.shouldLog(LogLevelError) {
|
||||
rc.c.log(LogLevelError, err.Error(), nil)
|
||||
err = rc.rxErrorResponse(msg)
|
||||
if rc.shouldLog(LogLevelError) {
|
||||
rc.log(LogLevelError, err.Error(), nil)
|
||||
}
|
||||
return
|
||||
case *pgproto3.CopyBothResponse:
|
||||
@@ -269,13 +269,13 @@ func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err
|
||||
h := &ServerHeartbeat{ServerWalEnd: serverWalEnd, ServerTime: serverTime, ReplyRequested: replyNow}
|
||||
return &ReplicationMessage{ServerHeartbeat: h}, nil
|
||||
default:
|
||||
if rc.c.shouldLog(LogLevelError) {
|
||||
rc.c.log(LogLevelError, "Unexpected data playload message type", map[string]interface{}{"type": msgType})
|
||||
if rc.shouldLog(LogLevelError) {
|
||||
rc.log(LogLevelError, "Unexpected data playload message type", map[string]interface{}{"type": msgType})
|
||||
}
|
||||
}
|
||||
default:
|
||||
if rc.c.shouldLog(LogLevelError) {
|
||||
rc.c.log(LogLevelError, "Unexpected replication message type", map[string]interface{}{"type": msg})
|
||||
if rc.shouldLog(LogLevelError) {
|
||||
rc.log(LogLevelError, "Unexpected replication message type", map[string]interface{}{"type": msg})
|
||||
}
|
||||
}
|
||||
return
|
||||
@@ -300,12 +300,12 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if err := rc.c.conn.SetDeadline(time.Now()); err != nil {
|
||||
if err := rc.conn.SetDeadline(time.Now()); err != nil {
|
||||
rc.Close() // Close connection if unable to set deadline
|
||||
return
|
||||
}
|
||||
rc.c.closedChan <- ctx.Err()
|
||||
case <-rc.c.doneChan:
|
||||
rc.closedChan <- ctx.Err()
|
||||
case <-rc.doneChan:
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -313,8 +313,8 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl
|
||||
|
||||
var err error
|
||||
select {
|
||||
case err = <-rc.c.closedChan:
|
||||
if err := rc.c.conn.SetDeadline(time.Time{}); err != nil {
|
||||
case err = <-rc.closedChan:
|
||||
if err := rc.conn.SetDeadline(time.Time{}); err != nil {
|
||||
rc.Close() // Close connection if unable to disable deadline
|
||||
return nil, err
|
||||
}
|
||||
@@ -322,7 +322,7 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl
|
||||
if opErr == nil {
|
||||
err = nil
|
||||
}
|
||||
case rc.c.doneChan <- struct{}{}:
|
||||
case rc.doneChan <- struct{}{}:
|
||||
err = opErr
|
||||
}
|
||||
|
||||
@@ -330,34 +330,34 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl
|
||||
}
|
||||
|
||||
func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) {
|
||||
rc.c.lastActivityTime = time.Now()
|
||||
rc.lastActivityTime = time.Now()
|
||||
|
||||
rows := rc.c.getRows(sql, nil)
|
||||
rows := rc.getRows(sql, nil)
|
||||
|
||||
if err := rc.c.lock(); err != nil {
|
||||
if err := rc.lock(); err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
rows.unlockConn = true
|
||||
|
||||
err := rc.c.sendSimpleQuery(sql)
|
||||
err := rc.sendSimpleQuery(sql)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
}
|
||||
|
||||
msg, err := rc.c.rxMsg()
|
||||
msg, err := rc.rxMsg()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.RowDescription:
|
||||
rows.fields = rc.c.rxRowDescription(msg)
|
||||
rows.fields = rc.rxRowDescription(msg)
|
||||
// We don't have c.PgTypes here because we're a replication
|
||||
// connection. This means the field descriptions will have
|
||||
// only OIDs. Not much we can do about this.
|
||||
default:
|
||||
if e := rc.c.processContextFreeMsg(msg); e != nil {
|
||||
if e := rc.processContextFreeMsg(msg); e != nil {
|
||||
rows.fatal(e)
|
||||
return rows, e
|
||||
}
|
||||
@@ -417,7 +417,7 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti
|
||||
queryString += fmt.Sprintf(" ( %s )", strings.Join(pluginArguments, ", "))
|
||||
}
|
||||
|
||||
if err = rc.c.sendQuery(queryString); err != nil {
|
||||
if err = rc.sendQuery(queryString); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -431,8 +431,8 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti
|
||||
var r *ReplicationMessage
|
||||
r, err = rc.WaitForReplicationMessage(ctx)
|
||||
if err != nil && r != nil {
|
||||
if rc.c.shouldLog(LogLevelError) {
|
||||
rc.c.log(LogLevelError, "Unexpected replication message", map[string]interface{}{"msg": r, "err": err})
|
||||
if rc.shouldLog(LogLevelError) {
|
||||
rc.log(LogLevelError, "Unexpected replication message", map[string]interface{}{"msg": r, "err": err})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -441,7 +441,7 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti
|
||||
|
||||
// Create the replication slot, using the given name and output plugin.
|
||||
func (rc *ReplicationConn) CreateReplicationSlot(slotName, outputPlugin string) (err error) {
|
||||
_, err = rc.c.Exec(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s NOEXPORT_SNAPSHOT", slotName, outputPlugin))
|
||||
_, err = rc.Exec(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s NOEXPORT_SNAPSHOT", slotName, outputPlugin))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -459,6 +459,6 @@ func (rc *ReplicationConn) CreateReplicationSlotEx(slotName, outputPlugin string
|
||||
|
||||
// Drop the replication slot for the given name
|
||||
func (rc *ReplicationConn) DropReplicationSlot(slotName string) (err error) {
|
||||
_, err = rc.c.Exec(fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName))
|
||||
_, err = rc.Exec(fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName))
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user