diff --git a/replication.go b/replication.go index 9bc4a1a4..a251172d 100644 --- a/replication.go +++ b/replication.go @@ -270,16 +270,43 @@ func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err // // This returns the context error when there is no replication message before // the context is canceled. -func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (r *ReplicationMessage, err error) { - err = rc.c.initContext(ctx) - if err != nil { - return nil, err +func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*ReplicationMessage, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: } - defer func() { - err = rc.c.termContext(err) + + go func() { + select { + case <-ctx.Done(): + if err := rc.c.conn.SetDeadline(time.Now()); err != nil { + rc.Close() // Close connection if unable to set deadline + return + } + rc.c.closedChan <- ctx.Err() + case <-rc.c.doneChan: + } }() - return rc.readReplicationMessage() + r, opErr := rc.readReplicationMessage() + + var err error + select { + case err = <-rc.c.closedChan: + if err := rc.c.conn.SetDeadline(time.Time{}); err != nil { + rc.Close() // Close connection if unable to disable deadline + return nil, err + } + + if opErr == nil { + err = nil + } + case rc.c.doneChan <- struct{}{}: + err = opErr + } + + return r, err } func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { diff --git a/replication_test.go b/replication_test.go index 43793f3c..1a8063e5 100644 --- a/replication_test.go +++ b/replication_test.go @@ -3,12 +3,13 @@ package pgx_test import ( "context" "fmt" - "github.com/jackc/pgx" "reflect" "strconv" "strings" "testing" "time" + + "github.com/jackc/pgx" ) // This function uses a postgresql 9.6 specific column @@ -47,14 +48,19 @@ func TestSimpleReplicationConnection(t *testing.T) { } conn := mustConnect(t, *replicationConnConfig) - defer closeConn(t, conn) + defer func() { + // Ensure replication slot is destroyed, but don't check for errors as it + // should have already been destroyed. + conn.Exec("select pg_drop_replication_slot('pgx_test')") + closeConn(t, conn) + }() replicationConn := mustReplicationConnect(t, *replicationConnConfig) defer closeReplicationConn(t, replicationConn) err = replicationConn.CreateReplicationSlot("pgx_test", "test_decoding") if err != nil { - t.Logf("replication slot create failed: %v", err) + t.Fatalf("replication slot create failed: %v", err) } // Do a simple change so we can get some wal data