diff --git a/.travis.yml b/.travis.yml index 32b35bbd..d9ea43b0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,8 @@ language: go go: - - 1.7.1 - - 1.6.3 + - 1.7.4 + - 1.6.4 - tip # Derived from https://github.com/lib/pq/blob/master/.travis.yml @@ -34,7 +34,6 @@ env: - PGVERSION=9.4 - PGVERSION=9.3 - PGVERSION=9.2 - - PGVERSION=9.1 # The tricky test user, below, has to actually exist so that it can be used in a test # of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. diff --git a/CHANGELOG.md b/CHANGELOG.md index bedf106b..126baef4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## Fixes * Oid underlying type changed to uint32, previously it was incorrectly int32 (Manni Wood) +* Explicitly close checked-in connections on ConnPool.Reset, previously they were closed by GC ## Features @@ -13,6 +14,8 @@ * Add NullOid type (Manni Wood) * Add json/jsonb binary support to allow use with CopyTo * Add named error ErrAcquireTimeout (Alexander Staubo) +* Add logical replication decoding (Kris Wehner) +* Add PgxScanner interface to allow types to simultaneously support database/sql and pgx (Jack Christensen) ## Compatibility diff --git a/conn_pool.go b/conn_pool.go index 4bb64a24..6614c4f0 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -255,8 +255,13 @@ func (p *ConnPool) Reset() { defer p.cond.L.Unlock() p.resetCount++ - p.allConnections = make([]*Conn, 0, p.maxConnections) - p.availableConnections = make([]*Conn, 0, p.maxConnections) + p.allConnections = p.allConnections[0:0] + + for _, conn := range p.availableConnections { + conn.Close() + } + + p.availableConnections = p.availableConnections[0:0] } // invalidateAcquired causes all acquired connections to be closed when released. diff --git a/conn_pool_test.go b/conn_pool_test.go index 0bbda0bc..9d03fad3 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -465,32 +465,38 @@ func TestPoolReleaseDiscardsDeadConnections(t *testing.T) { } } -func TestConnPoolReset(t *testing.T) { +func TestConnPoolResetClosesCheckedOutConnectionsOnRelease(t *testing.T) { t.Parallel() pool := createConnPool(t, 5) defer pool.Close() inProgressRows := []*pgx.Rows{} + var inProgressPIDs []int32 // Start some queries and reset pool while they are in progress for i := 0; i < 10; i++ { - rows, err := pool.Query("select generate_series(1,5)::bigint") + rows, err := pool.Query("select pg_backend_pid() union all select 1 union all select 2") if err != nil { t.Fatal(err) } + rows.Next() + var pid int32 + rows.Scan(&pid) + inProgressPIDs = append(inProgressPIDs, pid) + inProgressRows = append(inProgressRows, rows) pool.Reset() } // Check that the queries are completed for _, rows := range inProgressRows { - var expectedN int64 + var expectedN int32 for rows.Next() { expectedN++ - var n int64 + var n int32 err := rows.Scan(&n) if err != nil { t.Fatal(err) @@ -510,6 +516,75 @@ func TestConnPoolReset(t *testing.T) { if stats.CurrentConnections != 0 || stats.AvailableConnections != 0 { t.Fatalf("Unexpected connection pool stats: %v", stats) } + + var connCount int + err := pool.QueryRow("select count(*) from pg_stat_activity where pid = any($1::int4[])", inProgressPIDs).Scan(&connCount) + if err != nil { + t.Fatal(err) + } + if connCount != 0 { + t.Fatalf("%d connections not closed", connCount) + } +} + +func TestConnPoolResetClosesCheckedInConnections(t *testing.T) { + t.Parallel() + + pool := createConnPool(t, 5) + defer pool.Close() + + inProgressRows := []*pgx.Rows{} + var inProgressPIDs []int32 + + // Start some queries and reset pool while they are in progress + for i := 0; i < 5; i++ { + rows, err := pool.Query("select pg_backend_pid()") + if err != nil { + t.Fatal(err) + } + + inProgressRows = append(inProgressRows, rows) + } + + // Check that the queries are completed + for _, rows := range inProgressRows { + for rows.Next() { + var pid int32 + err := rows.Scan(&pid) + if err != nil { + t.Fatal(err) + } + inProgressPIDs = append(inProgressPIDs, pid) + + } + + if err := rows.Err(); err != nil { + t.Fatal(err) + } + } + + // Ensure pool is fully connected and available + stats := pool.Stat() + if stats.CurrentConnections != 5 || stats.AvailableConnections != 5 { + t.Fatalf("Unexpected connection pool stats: %v", stats) + } + + pool.Reset() + + // Pool should be empty after reset + stats = pool.Stat() + if stats.CurrentConnections != 0 || stats.AvailableConnections != 0 { + t.Fatalf("Unexpected connection pool stats: %v", stats) + } + + var connCount int + err := pool.QueryRow("select count(*) from pg_stat_activity where pid = any($1::int4[])", inProgressPIDs).Scan(&connCount) + if err != nil { + t.Fatal(err) + } + if connCount != 0 { + t.Fatalf("%d connections not closed", connCount) + } } func TestConnPoolTransaction(t *testing.T) { diff --git a/doc.go b/doc.go index 14843c28..514d51a0 100644 --- a/doc.go +++ b/doc.go @@ -157,14 +157,15 @@ Custom Type Support pgx includes support for the common data types like integers, floats, strings, dates, and times that have direct mappings between Go and SQL. Support can be added for additional types like point, hstore, numeric, etc. that do not have -direct mappings in Go by the types implementing Scanner and Encoder. +direct mappings in Go by the types implementing ScannerPgx and Encoder. Custom types can support text or binary formats. Binary format can provide a large performance increase. The natural place for deciding the format for a -value would be in Scanner as it is responsible for decoding the returned data. -However, that is impossible as the query has already been sent by the time the -Scanner is invoked. The solution to this is the global DefaultTypeFormats. If a -custom type prefers binary format it should register it there. +value would be in ScannerPgx as it is responsible for decoding the returned +data. However, that is impossible as the query has already been sent by the time +the ScannerPgx is invoked. The solution to this is the global +DefaultTypeFormats. If a custom type prefers binary format it should register it +there. pgx.DefaultTypeFormats["point"] = pgx.BinaryFormatCode diff --git a/example_custom_type_test.go b/example_custom_type_test.go index ddf4732d..674466f3 100644 --- a/example_custom_type_test.go +++ b/example_custom_type_test.go @@ -18,7 +18,7 @@ type NullPoint struct { Valid bool // Valid is true if not NULL } -func (p *NullPoint) Scan(vr *pgx.ValueReader) error { +func (p *NullPoint) ScanPgx(vr *pgx.ValueReader) error { if vr.Type().DataTypeName != "point" { return pgx.SerializationError(fmt.Sprintf("NullPoint.Scan cannot decode %s (OID %d)", vr.Type().DataTypeName, vr.Type().DataType)) } diff --git a/helper_test.go b/helper_test.go index ed5a9644..eff731e8 100644 --- a/helper_test.go +++ b/helper_test.go @@ -13,6 +13,15 @@ func mustConnect(t testing.TB, config pgx.ConnConfig) *pgx.Conn { return conn } +func mustReplicationConnect(t testing.TB, config pgx.ConnConfig) *pgx.ReplicationConn { + conn, err := pgx.ReplicationConnect(config) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + return conn +} + + func closeConn(t testing.TB, conn *pgx.Conn) { err := conn.Close() if err != nil { @@ -20,6 +29,13 @@ func closeConn(t testing.TB, conn *pgx.Conn) { } } +func closeReplicationConn(t testing.TB, conn *pgx.ReplicationConn) { + err := conn.Close() + if err != nil { + t.Fatalf("conn.Close unexpectedly failed: %v", err) + } +} + func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgx.CommandTag) { var err error if commandTag, err = conn.Exec(sql, arguments...); err != nil { diff --git a/msg_reader.go b/msg_reader.go index 43e80d98..0c3c23b8 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -21,7 +21,7 @@ func (r *msgReader) Err() error { return r.err } -// fatal tells r that a Fatal error has occurred +// fatal tells rc that a Fatal error has occurred func (r *msgReader) fatal(err error) { if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgBytesRemaining", r.msgBytesRemaining) diff --git a/query.go b/query.go index 30e0476e..778bc9cc 100644 --- a/query.go +++ b/query.go @@ -264,6 +264,11 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } + } else if s, ok := d.(PgxScanner); ok { + err = s.ScanPgx(vr) + if err != nil { + rows.Fatal(scanArgError{col: i, err: err}) + } } else if s, ok := d.(sql.Scanner); ok { var val interface{} if 0 <= vr.Len() { diff --git a/query_test.go b/query_test.go index 791b65cc..f5250589 100644 --- a/query_test.go +++ b/query_test.go @@ -3,6 +3,7 @@ package pgx_test import ( "bytes" "database/sql" + "fmt" "strings" "testing" "time" @@ -291,6 +292,67 @@ func TestConnQueryScanner(t *testing.T) { ensureConnValid(t, conn) } +type pgxNullInt64 struct { + Int64 int64 + Valid bool // Valid is true if Int64 is not NULL +} + +func (n *pgxNullInt64) ScanPgx(vr *pgx.ValueReader) error { + if vr.Type().DataType != pgx.Int8Oid { + return pgx.SerializationError(fmt.Sprintf("pgxNullInt64.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Int64, n.Valid = 0, false + return nil + } + n.Valid = true + + err := pgx.Decode(vr, &n.Int64) + if err != nil { + return err + } + return vr.Err() +} + +func TestConnQueryPgxScanner(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + rows, err := conn.Query("select null::int8, 1::int8") + if err != nil { + t.Fatalf("conn.Query failed: %v", err) + } + + ok := rows.Next() + if !ok { + t.Fatal("rows.Next terminated early") + } + + var n, m pgxNullInt64 + err = rows.Scan(&n, &m) + if err != nil { + t.Fatalf("rows.Scan failed: %v", err) + } + rows.Close() + + if n.Valid { + t.Error("Null should not be valid, but it was") + } + + if !m.Valid { + t.Error("1 should be valid, but it wasn't") + } + + if m.Int64 != 1 { + t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64) + } + + ensureConnValid(t, conn) +} + func TestConnQueryErrorWhileReturningRows(t *testing.T) { t.Parallel() diff --git a/replication.go b/replication.go index 7d4c56e2..7b28d6b6 100644 --- a/replication.go +++ b/replication.go @@ -8,10 +8,10 @@ import ( ) const ( - copyBothResponse = 'W' - walData = 'w' - senderKeepalive = 'k' - standbyStatusUpdate = 'r' + copyBothResponse = 'W' + walData = 'w' + senderKeepalive = 'k' + standbyStatusUpdate = 'r' initialReplicationResponseTimeout = 5 * time.Second ) @@ -151,11 +151,28 @@ func NewStandbyStatus(walPositions ...uint64) (status *StandbyStatus, err error) return } +func ReplicationConnect(config ConnConfig) (r *ReplicationConn, err error) { + if config.RuntimeParams == nil { + config.RuntimeParams = make(map[string]string) + } + config.RuntimeParams["replication"] = "database" + + c, err := Connect(config) + if err != nil { + return + } + return &ReplicationConn{c: c}, nil +} + +type ReplicationConn struct { + c *Conn +} + // Send standby status to the server, which both acts as a keepalive // message to the server, as well as carries the WAL position of the // client, which then updates the server's replication slot position. -func (c *Conn) SendStandbyStatus(k *StandbyStatus) (err error) { - writeBuf := newWriteBuf(c, copyData) +func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) { + writeBuf := newWriteBuf(rc.c, copyData) writeBuf.WriteByte(standbyStatusUpdate) writeBuf.WriteInt64(int64(k.WalWritePosition)) writeBuf.WriteInt64(int64(k.WalFlushPosition)) @@ -165,47 +182,44 @@ func (c *Conn) SendStandbyStatus(k *StandbyStatus) (err error) { writeBuf.closeMsg() - _, err = c.conn.Write(writeBuf.buf) + _, err = rc.c.conn.Write(writeBuf.buf) if err != nil { - c.die(err) + rc.c.die(err) } return } -// Send the message to formally stop the replication stream. This -// is done before calling Close() during a clean shutdown. -func (c *Conn) StopReplication() (err error) { - writeBuf := newWriteBuf(c, copyDone) - - writeBuf.closeMsg() - - _, err = c.conn.Write(writeBuf.buf) - if err != nil { - c.die(err) - } - return +func (rc *ReplicationConn) Close() error { + return rc.c.Close() } +func (rc *ReplicationConn) IsAlive() bool { + return rc.c.IsAlive() +} -func (c *Conn) readReplicationMessage() (r *ReplicationMessage, err error) { +func (rc *ReplicationConn) CauseOfDeath() error { + return rc.c.CauseOfDeath() +} + +func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err error) { var t byte var reader *msgReader - t, reader, err = c.rxMsg() + t, reader, err = rc.c.rxMsg() if err != nil { return } switch t { case noticeResponse: - pgError := c.rxErrorResponse(reader) - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, pgError.Error()) + pgError := rc.c.rxErrorResponse(reader) + if rc.c.shouldLog(LogLevelInfo) { + rc.c.log(LogLevelInfo, pgError.Error()) } case errorResponse: - err = c.rxErrorResponse(reader) - if c.shouldLog(LogLevelError) { - c.log(LogLevelError, err.Error()) + err = rc.c.rxErrorResponse(reader) + if rc.c.shouldLog(LogLevelError) { + rc.c.log(LogLevelError, err.Error()) } return case copyBothResponse: @@ -235,13 +249,13 @@ func (c *Conn) readReplicationMessage() (r *ReplicationMessage, err error) { h := &ServerHeartbeat{ServerWalEnd: uint64(serverWalEnd), ServerTime: uint64(serverTime), ReplyRequested: replyNow} return &ReplicationMessage{ServerHeartbeat: h}, nil default: - if c.shouldLog(LogLevelError) { - c.log(LogLevelError,"Unexpected data playload message type %v", t) + if rc.c.shouldLog(LogLevelError) { + rc.c.log(LogLevelError, "Unexpected data playload message type %v", t) } } default: - if c.shouldLog(LogLevelError) { - c.log(LogLevelError,"Unexpected replication message type %v", t) + if rc.c.shouldLog(LogLevelError) { + rc.c.log(LogLevelError, "Unexpected replication message type %v", t) } } return @@ -256,7 +270,7 @@ func (c *Conn) readReplicationMessage() (r *ReplicationMessage, err error) { // // This returns pgx.ErrNotificationTimeout when there is no replication message by the specified // duration. -func (c *Conn) WaitForReplicationMessage(timeout time.Duration) (r *ReplicationMessage, err error) { +func (rc *ReplicationConn) WaitForReplicationMessage(timeout time.Duration) (r *ReplicationMessage, err error) { var zeroTime time.Time deadline := time.Now().Add(timeout) @@ -269,27 +283,95 @@ func (c *Conn) WaitForReplicationMessage(timeout time.Duration) (r *ReplicationM // deadline and peek into the reader. If a timeout error occurs there // we don't break the pgx connection. If the Peek returns that data // is available then we turn off the read deadline before the rxMsg. - err = c.conn.SetReadDeadline(deadline) + err = rc.c.conn.SetReadDeadline(deadline) if err != nil { return nil, err } // Wait until there is a byte available before continuing onto the normal msg reading path - _, err = c.reader.Peek(1) + _, err = rc.c.reader.Peek(1) if err != nil { - c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline + rc.c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline if err, ok := err.(*net.OpError); ok && err.Timeout() { return nil, ErrNotificationTimeout } return nil, err } - err = c.conn.SetReadDeadline(zeroTime) + err = rc.c.conn.SetReadDeadline(zeroTime) if err != nil { return nil, err } - return c.readReplicationMessage() + return rc.readReplicationMessage() +} + +func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { + rc.c.lastActivityTime = time.Now() + + rows := rc.c.getRows(sql, nil) + + if err := rc.c.lock(); err != nil { + rows.abort(err) + return rows, err + } + rows.unlockConn = true + + err := rc.c.sendSimpleQuery(sql) + if err != nil { + rows.abort(err) + } + + var t byte + var r *msgReader + t, r, err = rc.c.rxMsg() + if err != nil { + return nil, err + } + + switch t { + case rowDescription: + rows.fields = rc.c.rxRowDescription(r) + // We don't have c.PgTypes here because we're a replication + // connection. This means the field descriptions will have + // only Oids. Not much we can do about this. + default: + if e := rc.c.processContextFreeMsg(t, r); e != nil { + rows.abort(e) + return rows, e + } + } + + return rows, rows.err +} + +// Execute the "IDENTIFY_SYSTEM" command as documented here: +// https://www.postgresql.org/docs/9.5/static/protocol-replication.html +// +// This will return (if successful) a result set that has a single row +// that contains the systemid, current timeline, xlogpos and database +// name. +// +// NOTE: Because this is a replication mode connection, we don't have +// type names, so the field descriptions in the result will have only +// Oids and no DataTypeName values +func (rc *ReplicationConn) IdentifySystem() (r *Rows, err error) { + return rc.sendReplicationModeQuery("IDENTIFY_SYSTEM") +} + +// Execute the "TIMELINE_HISTORY" command as documented here: +// https://www.postgresql.org/docs/9.5/static/protocol-replication.html +// +// This will return (if successful) a result set that has a single row +// that contains the filename of the history file and the content +// of the history file. If called for timeline 1, typically this will +// generate an error that the timeline history file does not exist. +// +// NOTE: Because this is a replication mode connection, we don't have +// type names, so the field descriptions in the result will have only +// Oids and no DataTypeName values +func (rc *ReplicationConn) TimelineHistory(timeline int) (r *Rows, err error) { + return rc.sendReplicationModeQuery(fmt.Sprintf("TIMELINE_HISTORY %d", timeline)) } // Start a replication connection, sending WAL data to the given replication @@ -303,7 +385,7 @@ func (c *Conn) WaitForReplicationMessage(timeout time.Duration) (r *ReplicationM // // This function assumes that slotName has already been created. In order to omit the timeline argument // pass a -1 for the timeline to get the server default behavior. -func (c *Conn) StartReplication(slotName string, startLsn uint64, timeline int64, pluginArguments ...string) (err error) { +func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, timeline int64, pluginArguments ...string) (err error) { var queryString string if timeline >= 0 { queryString = fmt.Sprintf("START_REPLICATION SLOT %s LOGICAL %s TIMELINE %d", slotName, FormatLSN(startLsn), timeline) @@ -315,7 +397,7 @@ func (c *Conn) StartReplication(slotName string, startLsn uint64, timeline int64 queryString += fmt.Sprintf(" %s", arg) } - if err = c.sendQuery(queryString); err != nil { + if err = rc.c.sendQuery(queryString); err != nil { return } @@ -324,12 +406,24 @@ func (c *Conn) StartReplication(slotName string, startLsn uint64, timeline int64 // started. This call will either return nil, nil or if it returns an error // that indicates the start replication command failed var r *ReplicationMessage - r, err = c.WaitForReplicationMessage(initialReplicationResponseTimeout) + r, err = rc.WaitForReplicationMessage(initialReplicationResponseTimeout) if err != nil && r != nil { - if c.shouldLog(LogLevelError) { - c.log(LogLevelError, "Unxpected replication message %v", r) + if rc.c.shouldLog(LogLevelError) { + rc.c.log(LogLevelError, "Unxpected replication message %v", r) } } return } + +// Create the replication slot, using the given name and output plugin. +func (rc *ReplicationConn) CreateReplicationSlot(slotName, outputPlugin string) (err error) { + _, err = rc.c.Exec(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s", slotName, outputPlugin)) + return +} + +// Drop the replication slot for the given name +func (rc *ReplicationConn) DropReplicationSlot(slotName string) (err error) { + _, err = rc.c.Exec(fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName)) + return +} diff --git a/replication_test.go b/replication_test.go index 866fe45e..4f810c78 100644 --- a/replication_test.go +++ b/replication_test.go @@ -1,13 +1,13 @@ package pgx_test import ( + "fmt" "github.com/jackc/pgx" + "reflect" "strconv" "strings" "testing" "time" - "reflect" - "fmt" ) // This function uses a postgresql 9.6 specific column @@ -48,13 +48,10 @@ func TestSimpleReplicationConnection(t *testing.T) { conn := mustConnect(t, *replicationConnConfig) defer closeConn(t, conn) - replicationConnConfig.RuntimeParams = make(map[string]string) - replicationConnConfig.RuntimeParams["replication"] = "database" + replicationConn := mustReplicationConnect(t, *replicationConnConfig) + defer closeReplicationConn(t, replicationConn) - replicationConn := mustConnect(t, *replicationConnConfig) - defer closeConn(t, replicationConn) - - _, err = replicationConn.Exec("CREATE_REPLICATION_SLOT pgx_test LOGICAL test_decoding") + err = replicationConn.CreateReplicationSlot("pgx_test", "test_decoding") if err != nil { t.Logf("replication slot create failed: %v", err) } @@ -153,16 +150,170 @@ func TestSimpleReplicationConnection(t *testing.T) { t.Errorf("Failed to create standby status %v", err) } replicationConn.SendStandbyStatus(status) - replicationConn.StopReplication() + restartLsn := getConfirmedFlushLsnFor(t, conn, "pgx_test") + integerRestartLsn, _ := pgx.ParseLSN(restartLsn) + if integerRestartLsn != maxWal { + t.Fatalf("Wal offset update failed, expected %s found %s", pgx.FormatLSN(maxWal), restartLsn) + } + + closeReplicationConn(t, replicationConn) + + replicationConn2 := mustReplicationConnect(t, *replicationConnConfig) + defer closeReplicationConn(t, replicationConn2) + + err = replicationConn2.DropReplicationSlot("pgx_test") + if err != nil { + t.Fatalf("Failed to drop replication slot: %v", err) + } + + droppedLsn := getConfirmedFlushLsnFor(t, conn, "pgx_test") + if droppedLsn != "" { + t.Errorf("Got odd flush lsn %s for supposedly dropped slot", droppedLsn) + } +} + +func TestReplicationConn_DropReplicationSlot(t *testing.T) { + if replicationConnConfig == nil { + t.Skip("Skipping due to undefined replicationConnConfig") + } + + replicationConn := mustReplicationConnect(t, *replicationConnConfig) + defer closeReplicationConn(t, replicationConn) + + err := replicationConn.CreateReplicationSlot("pgx_slot_test", "test_decoding") + if err != nil { + t.Logf("replication slot create failed: %v", err) + } + err = replicationConn.DropReplicationSlot("pgx_slot_test") + if err != nil { + t.Fatalf("Failed to drop replication slot: %v", err) + } + + // We re-create to ensure the drop worked. + err = replicationConn.CreateReplicationSlot("pgx_slot_test", "test_decoding") + if err != nil { + t.Logf("replication slot create failed: %v", err) + } + + // And finally we drop to ensure we don't leave dirty state + err = replicationConn.DropReplicationSlot("pgx_slot_test") + if err != nil { + t.Fatalf("Failed to drop replication slot: %v", err) + } +} + +func TestIdentifySystem(t *testing.T) { + if replicationConnConfig == nil { + t.Skip("Skipping due to undefined replicationConnConfig") + } + + replicationConn2 := mustReplicationConnect(t, *replicationConnConfig) + defer closeReplicationConn(t, replicationConn2) + + r, err := replicationConn2.IdentifySystem() + if err != nil { + t.Error(err) + } + defer r.Close() + for _, fd := range r.FieldDescriptions() { + t.Logf("Field: %s of type %v", fd.Name, fd.DataType) + } + + var rowCount int + for r.Next() { + rowCount++ + values, err := r.Values() + if err != nil { + t.Error(err) + } + t.Logf("Row values: %v", values) + } + if r.Err() != nil { + t.Error(r.Err()) + } + + if rowCount == 0 { + t.Errorf("Failed to find any rows: %d", rowCount) + } +} + +func getCurrentTimeline(t *testing.T, rc *pgx.ReplicationConn) int { + r, err := rc.IdentifySystem() + if err != nil { + t.Error(err) + } + defer r.Close() + for r.Next() { + values, e := r.Values() + if e != nil { + t.Error(e) + } + timeline, e := strconv.Atoi(values[1].(string)) + if e != nil { + t.Error(e) + } + return timeline + } + t.Fatal("Failed to read timeline") + return -1 +} + +func TestGetTimelineHistory(t *testing.T) { + if replicationConnConfig == nil { + t.Skip("Skipping due to undefined replicationConnConfig") + } + + replicationConn := mustReplicationConnect(t, *replicationConnConfig) + defer closeReplicationConn(t, replicationConn) + + timeline := getCurrentTimeline(t, replicationConn) + + r, err := replicationConn.TimelineHistory(timeline) + if err != nil { + t.Errorf("%#v", err) + } + defer r.Close() + + for _, fd := range r.FieldDescriptions() { + t.Logf("Field: %s of type %v", fd.Name, fd.DataType) + } + + var rowCount int + for r.Next() { + rowCount++ + values, err := r.Values() + if err != nil { + t.Error(err) + } + t.Logf("Row values: %v", values) + } + if r.Err() != nil { + if strings.Contains(r.Err().Error(), "No such file or directory") { + // This is normal, this means the timeline we're on has no + // history, which is the common case in a test db that + // has only one timeline + return + } + t.Error(r.Err()) + } + + // If we have a timeline history (see above) there should have been + // rows emitted + if rowCount == 0 { + t.Errorf("Failed to find any rows: %d", rowCount) + } +} + +func TestStandbyStatusParsing(t *testing.T) { // Let's push the boundary conditions of the standby status and ensure it errors correctly - status, err = pgx.NewStandbyStatus(0,1,2,3,4) + status, err := pgx.NewStandbyStatus(0, 1, 2, 3, 4) if err == nil { - t.Errorf("Expected error from new standby status, got %v",status) + t.Errorf("Expected error from new standby status, got %v", status) } // And if you provide 3 args, ensure the right fields are set - status, err = pgx.NewStandbyStatus(1,2,3) + status, err = pgx.NewStandbyStatus(1, 2, 3) if err != nil { t.Errorf("Failed to create test status: %v", err) } @@ -175,21 +326,4 @@ func TestSimpleReplicationConnection(t *testing.T) { if status.WalWritePosition != 3 { t.Errorf("Unexpected write position %d", status.WalWritePosition) } - - err = replicationConn.Close() - if err != nil { - t.Fatalf("Replication connection close failed: %v", err) - } - - restartLsn := getConfirmedFlushLsnFor(t, conn, "pgx_test") - integerRestartLsn, _ := pgx.ParseLSN(restartLsn) - if integerRestartLsn != maxWal { - t.Fatalf("Wal offset update failed, expected %s found %s", pgx.FormatLSN(maxWal), restartLsn) - } - - _, err = conn.Exec("select pg_drop_replication_slot($1)", "pgx_test") - if err != nil { - t.Fatalf("Failed to drop replication slot: %v", err) - } - } diff --git a/values.go b/values.go index 231a37f7..a59ca0c3 100644 --- a/values.go +++ b/values.go @@ -127,7 +127,9 @@ func (e SerializationError) Error() string { return string(e) } -// Scanner is an interface used to decode values from the PostgreSQL server. +// Deprecated: Scanner is an interface used to decode values from the PostgreSQL +// server. To allow types to support pgx and database/sql.Scan this interface +// has been deprecated in favor of PgxScanner. type Scanner interface { // Scan MUST check r.Type().DataType (to check by OID) or // r.Type().DataTypeName (to check by name) to ensure that it is scanning an @@ -137,6 +139,18 @@ type Scanner interface { Scan(r *ValueReader) error } +// PgxScanner is an interface used to decode values from the PostgreSQL server. +// It is used exactly the same as the Scanner interface. It simply has renamed +// the method. +type PgxScanner interface { + // ScanPgx MUST check r.Type().DataType (to check by OID) or + // r.Type().DataTypeName (to check by name) to ensure that it is scanning an + // expected column type. It also MUST check r.Type().FormatCode before + // decoding. It should not assume that it was called on a data type or format + // that it understands. + ScanPgx(r *ValueReader) error +} + // Encoder is an interface used to encode values for transmission to the // PostgreSQL server. type Encoder interface {