From 0c7a1fc13eb847a032c706bfc26dcdce922c8495 Mon Sep 17 00:00:00 2001 From: David Bariod Date: Tue, 15 Jan 2019 11:01:18 +0100 Subject: [PATCH 01/70] support binding of []int type to array integer --- pgtype/int4_array.go | 19 +++++++++++++++++++ pgtype/int4_array_test.go | 25 +++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go index 4e78ce71..86656524 100644 --- a/pgtype/int4_array.go +++ b/pgtype/int4_array.go @@ -23,6 +23,25 @@ func (dst *Int4Array) Set(src interface{}) error { switch value := src.(type) { + case []int: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []int32: if value == nil { *dst = Int4Array{Status: Null} diff --git a/pgtype/int4_array_test.go b/pgtype/int4_array_test.go index 602a3657..f0418600 100644 --- a/pgtype/int4_array_test.go +++ b/pgtype/int4_array_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "math" "reflect" "testing" @@ -54,8 +55,9 @@ func TestInt4ArrayTranscode(t *testing.T) { func TestInt4ArraySet(t *testing.T) { successfulTests := []struct { - source interface{} - result pgtype.Int4Array + source interface{} + result pgtype.Int4Array + expectedError bool }{ { source: []int32{1}, @@ -64,6 +66,17 @@ func TestInt4ArraySet(t *testing.T) { Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, + { + source: []int{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []int{1, math.MaxInt32 + 1, 2}, + expectedError: true, + }, { source: []uint32{1}, result: pgtype.Int4Array{ @@ -81,9 +94,17 @@ func TestInt4ArraySet(t *testing.T) { var r pgtype.Int4Array err := r.Set(tt.source) if err != nil { + if tt.expectedError { + continue + } t.Errorf("%d: %v", i, err) } + if tt.expectedError { + t.Errorf("%d: an error was expected, %v", i, tt) + continue + } + if !reflect.DeepEqual(r, tt.result) { t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) } From 3f1d975e4b35a6c8fe6e04d26c768f7b2ba9d1cc Mon Sep 17 00:00:00 2001 From: Josh Leverette Date: Thu, 17 Jan 2019 22:19:08 -0800 Subject: [PATCH 02/70] Fix encoding of ErrorResponse --- pgproto3/error_response.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/pgproto3/error_response.go b/pgproto3/error_response.go index 160234f2..987fe38a 100644 --- a/pgproto3/error_response.go +++ b/pgproto3/error_response.go @@ -115,70 +115,87 @@ func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { buf.Write(bigEndian.Uint32(0)) if src.Severity != "" { + buf.WriteByte('S') buf.WriteString(src.Severity) buf.WriteByte(0) } if src.Code != "" { + buf.WriteByte('C') buf.WriteString(src.Code) buf.WriteByte(0) } if src.Message != "" { + buf.WriteByte('M') buf.WriteString(src.Message) buf.WriteByte(0) } if src.Detail != "" { + buf.WriteByte('D') buf.WriteString(src.Detail) buf.WriteByte(0) } if src.Hint != "" { + buf.WriteByte('H') buf.WriteString(src.Hint) buf.WriteByte(0) } if src.Position != 0 { + buf.WriteByte('P') buf.WriteString(strconv.Itoa(int(src.Position))) buf.WriteByte(0) } if src.InternalPosition != 0 { + buf.WriteByte('p') buf.WriteString(strconv.Itoa(int(src.InternalPosition))) buf.WriteByte(0) } if src.InternalQuery != "" { + buf.WriteByte('q') buf.WriteString(src.InternalQuery) buf.WriteByte(0) } if src.Where != "" { + buf.WriteByte('W') buf.WriteString(src.Where) buf.WriteByte(0) } if src.SchemaName != "" { + buf.WriteByte('s') buf.WriteString(src.SchemaName) buf.WriteByte(0) } if src.TableName != "" { + buf.WriteByte('t') buf.WriteString(src.TableName) buf.WriteByte(0) } if src.ColumnName != "" { + buf.WriteByte('c') buf.WriteString(src.ColumnName) buf.WriteByte(0) } if src.DataTypeName != "" { + buf.WriteByte('d') buf.WriteString(src.DataTypeName) buf.WriteByte(0) } if src.ConstraintName != "" { + buf.WriteByte('n') buf.WriteString(src.ConstraintName) buf.WriteByte(0) } if src.File != "" { + buf.WriteByte('F') buf.WriteString(src.File) buf.WriteByte(0) } if src.Line != 0 { + buf.WriteByte('L') buf.WriteString(strconv.Itoa(int(src.Line))) buf.WriteByte(0) } if src.Routine != "" { + buf.WriteByte('R') buf.WriteString(src.Routine) buf.WriteByte(0) } From a48ad29c160ba65d6cf8532510b25f0b9ae765e6 Mon Sep 17 00:00:00 2001 From: Ilya Sivanev Date: Mon, 21 Jan 2019 17:49:11 +0300 Subject: [PATCH 03/70] Use more detailed error output of unknown field; --- conn.go | 3 ++- query.go | 8 +++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/conn.go b/conn.go index 968c9253..7b3db53d 100644 --- a/conn.go +++ b/conn.go @@ -1133,7 +1133,8 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared ps.FieldDescriptions[i].FormatCode = TextFormatCode } } else { - return nil, errors.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType) + fd := ps.FieldDescriptions[i] + return nil, errors.Errorf("unknown oid: %d, name: %s", fd.DataType, fd.Name) } } case *pgproto3.ReadyForQuery: diff --git a/query.go b/query.go index ad3ed84b..d85ee771 100644 --- a/query.go +++ b/query.go @@ -137,7 +137,8 @@ func (rows *Rows) Next() bool { rows.fields[i].DataTypeName = dt.Name rows.fields[i].FormatCode = TextFormatCode } else { - rows.fatal(errors.Errorf("unknown oid: %d", rows.fields[i].DataType)) + fd := rows.fields[i] + rows.fatal(errors.Errorf("unknown oid: %d, name: %s", fd.DataType, fd.Name)) return false } } @@ -259,7 +260,7 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { } } } else { - rows.fatal(scanArgError{col: i, err: errors.Errorf("unknown oid: %v", fd.DataType)}) + rows.fatal(scanArgError{col: i, err: errors.Errorf("unknown oid: %v, name: %s", fd.DataType, fd.Name)}) } } @@ -507,7 +508,8 @@ func (c *Conn) readUntilRowDescription() ([]FieldDescription, error) { if dt, ok := c.ConnInfo.DataTypeForOID(fieldDescriptions[i].DataType); ok { fieldDescriptions[i].DataTypeName = dt.Name } else { - return nil, errors.Errorf("unknown oid: %d", fieldDescriptions[i].DataType) + fd := fieldDescriptions[i] + return nil, errors.Errorf("unknown oid: %d, name: %s", fd.DataType, fd.Name) } } return fieldDescriptions, nil From 6067cfab4f674940e745d14579632430c2c33b32 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 28 Jan 2019 22:45:44 -0600 Subject: [PATCH 04/70] All Write errors are fatal With TLS connections a Write timeout caused by a SetDeadline permanently breaks the connection. However, the errors are reported as temporary. So there is no way to determine if it really is recoverable. As these were the only kind of Write error that was recovered all Write errors are now fatal to the connection. https://github.com/jackc/pgx/issues/494 https://github.com/jackc/pgx/issues/506 https://github.com/golang/go/issues/29971 --- batch.go | 6 ++---- conn.go | 27 ++++++--------------------- query.go | 4 ++-- 3 files changed, 10 insertions(+), 27 deletions(-) diff --git a/batch.go b/batch.go index 0d7f14cc..4b624387 100644 --- a/batch.go +++ b/batch.go @@ -133,11 +133,9 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { b.conn.pendingReadyForQueryCount++ } - n, err := b.conn.conn.Write(buf) + _, err = b.conn.conn.Write(buf) if err != nil { - if fatalWriteErr(n, err) { - b.conn.die(err) - } + b.conn.die(err) return err } diff --git a/conn.go b/conn.go index 7b3db53d..f8016220 100644 --- a/conn.go +++ b/conn.go @@ -1096,11 +1096,9 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared buf = appendDescribe(buf, 'S', name) buf = appendSync(buf) - n, err := c.conn.Write(buf) + _, err = c.conn.Write(buf) if err != nil { - if fatalWriteErr(n, err) { - c.die(err) - } + c.die(err) return nil, err } c.pendingReadyForQueryCount++ @@ -1360,11 +1358,9 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} buf = appendExecute(buf, "", 0) buf = appendSync(buf) - n, err := c.conn.Write(buf) + _, err = c.conn.Write(buf) if err != nil { - if fatalWriteErr(n, err) { - c.die(err) - } + c.die(err) return err } c.pendingReadyForQueryCount++ @@ -1372,17 +1368,6 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} return nil } -// fatalWriteError takes the response of a net.Conn.Write and determines if it is fatal -func fatalWriteErr(bytesWritten int, err error) bool { - // Partial writes break the connection - if bytesWritten > 0 { - return true - } - - netErr, is := err.(net.Error) - return !(is && netErr.Timeout()) -} - // Exec executes sql. sql can be either a prepared statement name or an SQL string. // arguments should be referenced positionally from the sql string as $1, $2, etc. func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { @@ -1791,8 +1776,8 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, buf = appendSync(buf) c.lastStmtSent = true - n, err := c.conn.Write(buf) - if err != nil && fatalWriteErr(n, err) { + _, err = c.conn.Write(buf) + if err != nil { c.die(err) return "", err } diff --git a/query.go b/query.go index d85ee771..27969be9 100644 --- a/query.go +++ b/query.go @@ -418,8 +418,8 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, buf = appendSync(buf) c.lastStmtSent = true - n, err := c.conn.Write(buf) - if err != nil && fatalWriteErr(n, err) { + _, err = c.conn.Write(buf) + if err != nil { rows.fatal(err) c.die(err) return rows, err From 74ea479b0ca6b4fa850ac4ebf0ed36cd481c8d6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?PLATEL=20K=C3=A9vin?= Date: Tue, 5 Feb 2019 11:01:22 +0100 Subject: [PATCH 05/70] Close issue #481 : Give access to the registered driver instance Some library use a driver to wrap its behavior and give additional functionality, as the datadog tracing library ("gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql") This commit aims to give access to this instance which can't be correctly initialized to due private fields without default values (the configuration map inside the driver) --- stdlib/sql.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/stdlib/sql.go b/stdlib/sql.go index b83e527b..ec5933f3 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -126,6 +126,12 @@ var ( fakeTxConns map[*pgx.Conn]*sql.Tx ) +// GetDefaultDriver return the driver initialize in the init function +// and used when register pgx driver +func GetDefaultDriver() *Driver { + return pgxDriver +} + type Driver struct { configMutex sync.Mutex configCount int64 From 8fe19f698b7144ef785a0f058d72392feab1d0a8 Mon Sep 17 00:00:00 2001 From: Ilya Sinelnikov Date: Thu, 28 Feb 2019 19:02:16 +0300 Subject: [PATCH 06/70] Fix PreferSimpleProtocol overwrite https://github.com/jackc/pgx/issues/495 --- conn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conn.go b/conn.go index f8016220..1693d159 100644 --- a/conn.go +++ b/conn.go @@ -707,7 +707,7 @@ func (old ConnConfig) Merge(other ConnConfig) ConnConfig { cc.Dial = other.Dial } - cc.PreferSimpleProtocol = other.PreferSimpleProtocol + cc.PreferSimpleProtocol = old.PreferSimpleProtocol || other.PreferSimpleProtocol cc.RuntimeParams = make(map[string]string) for k, v := range old.RuntimeParams { From 1edfd3b6820d11edd267c611f35a827892bb8d77 Mon Sep 17 00:00:00 2001 From: Jan Vcelak Date: Thu, 7 Mar 2019 12:32:26 +0100 Subject: [PATCH 07/70] Allow normal queries on replication connections The replication connection allows executing most of the SQL queries which are available on non-replication connections. --- replication.go | 76 +++++++++++++++++++++++++------------------------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/replication.go b/replication.go index 2efdcc79..a017dcc4 100644 --- a/replication.go +++ b/replication.go @@ -168,18 +168,18 @@ func ReplicationConnect(config ConnConfig) (r *ReplicationConn, err error) { if err != nil { return } - return &ReplicationConn{c: c}, nil + return &ReplicationConn{c}, nil } type ReplicationConn struct { - c *Conn + *Conn } // Send standby status to the server, which both acts as a keepalive // message to the server, as well as carries the WAL position of the // client, which then updates the server's replication slot position. func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) { - buf := rc.c.wbuf + buf := rc.wbuf buf = append(buf, copyData) sp := len(buf) buf = pgio.AppendInt32(buf, -1) @@ -193,46 +193,46 @@ func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) { pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - _, err = rc.c.conn.Write(buf) + _, err = rc.conn.Write(buf) if err != nil { - rc.c.die(err) + rc.die(err) } return } func (rc *ReplicationConn) Close() error { - return rc.c.Close() + return rc.Close() } func (rc *ReplicationConn) IsAlive() bool { - return rc.c.IsAlive() + return rc.IsAlive() } func (rc *ReplicationConn) CauseOfDeath() error { - return rc.c.CauseOfDeath() + return rc.CauseOfDeath() } func (rc *ReplicationConn) GetConnInfo() *pgtype.ConnInfo { - return rc.c.ConnInfo + return rc.ConnInfo } func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err error) { - msg, err := rc.c.rxMsg() + msg, err := rc.rxMsg() if err != nil { return } switch msg := msg.(type) { case *pgproto3.NoticeResponse: - pgError := rc.c.rxErrorResponse((*pgproto3.ErrorResponse)(msg)) - if rc.c.shouldLog(LogLevelInfo) { - rc.c.log(LogLevelInfo, pgError.Error(), nil) + pgError := rc.rxErrorResponse((*pgproto3.ErrorResponse)(msg)) + if rc.shouldLog(LogLevelInfo) { + rc.log(LogLevelInfo, pgError.Error(), nil) } case *pgproto3.ErrorResponse: - err = rc.c.rxErrorResponse(msg) - if rc.c.shouldLog(LogLevelError) { - rc.c.log(LogLevelError, err.Error(), nil) + err = rc.rxErrorResponse(msg) + if rc.shouldLog(LogLevelError) { + rc.log(LogLevelError, err.Error(), nil) } return case *pgproto3.CopyBothResponse: @@ -269,13 +269,13 @@ func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err h := &ServerHeartbeat{ServerWalEnd: serverWalEnd, ServerTime: serverTime, ReplyRequested: replyNow} return &ReplicationMessage{ServerHeartbeat: h}, nil default: - if rc.c.shouldLog(LogLevelError) { - rc.c.log(LogLevelError, "Unexpected data playload message type", map[string]interface{}{"type": msgType}) + if rc.shouldLog(LogLevelError) { + rc.log(LogLevelError, "Unexpected data playload message type", map[string]interface{}{"type": msgType}) } } default: - if rc.c.shouldLog(LogLevelError) { - rc.c.log(LogLevelError, "Unexpected replication message type", map[string]interface{}{"type": msg}) + if rc.shouldLog(LogLevelError) { + rc.log(LogLevelError, "Unexpected replication message type", map[string]interface{}{"type": msg}) } } return @@ -300,12 +300,12 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl go func() { select { case <-ctx.Done(): - if err := rc.c.conn.SetDeadline(time.Now()); err != nil { + if err := rc.conn.SetDeadline(time.Now()); err != nil { rc.Close() // Close connection if unable to set deadline return } - rc.c.closedChan <- ctx.Err() - case <-rc.c.doneChan: + rc.closedChan <- ctx.Err() + case <-rc.doneChan: } }() @@ -313,8 +313,8 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl var err error select { - case err = <-rc.c.closedChan: - if err := rc.c.conn.SetDeadline(time.Time{}); err != nil { + case err = <-rc.closedChan: + if err := rc.conn.SetDeadline(time.Time{}); err != nil { rc.Close() // Close connection if unable to disable deadline return nil, err } @@ -322,7 +322,7 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl if opErr == nil { err = nil } - case rc.c.doneChan <- struct{}{}: + case rc.doneChan <- struct{}{}: err = opErr } @@ -330,34 +330,34 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl } func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { - rc.c.lastActivityTime = time.Now() + rc.lastActivityTime = time.Now() - rows := rc.c.getRows(sql, nil) + rows := rc.getRows(sql, nil) - if err := rc.c.lock(); err != nil { + if err := rc.lock(); err != nil { rows.fatal(err) return rows, err } rows.unlockConn = true - err := rc.c.sendSimpleQuery(sql) + err := rc.sendSimpleQuery(sql) if err != nil { rows.fatal(err) } - msg, err := rc.c.rxMsg() + msg, err := rc.rxMsg() if err != nil { return nil, err } switch msg := msg.(type) { case *pgproto3.RowDescription: - rows.fields = rc.c.rxRowDescription(msg) + rows.fields = rc.rxRowDescription(msg) // We don't have c.PgTypes here because we're a replication // connection. This means the field descriptions will have // only OIDs. Not much we can do about this. default: - if e := rc.c.processContextFreeMsg(msg); e != nil { + if e := rc.processContextFreeMsg(msg); e != nil { rows.fatal(e) return rows, e } @@ -417,7 +417,7 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti queryString += fmt.Sprintf(" ( %s )", strings.Join(pluginArguments, ", ")) } - if err = rc.c.sendQuery(queryString); err != nil { + if err = rc.sendQuery(queryString); err != nil { return } @@ -431,8 +431,8 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti var r *ReplicationMessage r, err = rc.WaitForReplicationMessage(ctx) if err != nil && r != nil { - if rc.c.shouldLog(LogLevelError) { - rc.c.log(LogLevelError, "Unexpected replication message", map[string]interface{}{"msg": r, "err": err}) + if rc.shouldLog(LogLevelError) { + rc.log(LogLevelError, "Unexpected replication message", map[string]interface{}{"msg": r, "err": err}) } } @@ -441,7 +441,7 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti // Create the replication slot, using the given name and output plugin. func (rc *ReplicationConn) CreateReplicationSlot(slotName, outputPlugin string) (err error) { - _, err = rc.c.Exec(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s NOEXPORT_SNAPSHOT", slotName, outputPlugin)) + _, err = rc.Exec(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s NOEXPORT_SNAPSHOT", slotName, outputPlugin)) return } @@ -459,6 +459,6 @@ func (rc *ReplicationConn) CreateReplicationSlotEx(slotName, outputPlugin string // Drop the replication slot for the given name func (rc *ReplicationConn) DropReplicationSlot(slotName string) (err error) { - _, err = rc.c.Exec(fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName)) + _, err = rc.Exec(fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName)) return } From e08a188515546234166a30756f05cedd04019ac4 Mon Sep 17 00:00:00 2001 From: Robert Lin Date: Sun, 10 Mar 2019 23:38:11 -0700 Subject: [PATCH 08/70] Fix enum handling --- values.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/values.go b/values.go index 0c571d74..e7e6c1f7 100644 --- a/values.go +++ b/values.go @@ -189,7 +189,15 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid pgtype sp := len(buf) buf = pgio.AppendInt32(buf, -1) - argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf) + var argBuf []byte + switch valueEncoder := value.(type) { + case pgtype.BinaryEncoder: + argBuf, err = valueEncoder.EncodeBinary(ci, buf) + case pgtype.TextEncoder: + argBuf, err = valueEncoder.EncodeText(ci, buf) + default: + return nil, fmt.Errorf("invalid encode type %v", valueEncoder) + } if err != nil { return nil, err } From 9a3e403bdf10b1c850c0fc6bccad52415741d0b5 Mon Sep 17 00:00:00 2001 From: Robert Lin Date: Sun, 10 Mar 2019 23:38:34 -0700 Subject: [PATCH 09/70] Add rudementary enum transcode test --- values_test.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/values_test.go b/values_test.go index ddaf5468..e05dab37 100644 --- a/values_test.go +++ b/values_test.go @@ -50,6 +50,29 @@ func TestDateTranscode(t *testing.T) { } } +func TestEnumTranscode(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + _, err := conn.Exec("create type some_type as enum ('hello-world', 'goodbye-world')") + if err != nil { + t.Fatalf("Unexpected failure in test setup: %v", err) + } + defer conn.Exec("drop type some_type") + + var out string + var actual = "hello-world" + err = conn.QueryRow("select $1::some_type", actual).Scan(&out) + if err != nil { + t.Fatalf("Unexpected failure on QueryRow Scan: %v", err) + } + if actual != out { + t.Errorf("Did not transcode enum successfully: %s is not %s", out, actual) + } +} + func TestTimestampTzTranscode(t *testing.T) { t.Parallel() From 0a8645df197181e4d424e090c9c7f42175b035cd Mon Sep 17 00:00:00 2001 From: Robert Lin Date: Mon, 11 Mar 2019 00:14:06 -0700 Subject: [PATCH 10/70] Remove test --- values_test.go | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/values_test.go b/values_test.go index e05dab37..ddaf5468 100644 --- a/values_test.go +++ b/values_test.go @@ -50,29 +50,6 @@ func TestDateTranscode(t *testing.T) { } } -func TestEnumTranscode(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - _, err := conn.Exec("create type some_type as enum ('hello-world', 'goodbye-world')") - if err != nil { - t.Fatalf("Unexpected failure in test setup: %v", err) - } - defer conn.Exec("drop type some_type") - - var out string - var actual = "hello-world" - err = conn.QueryRow("select $1::some_type", actual).Scan(&out) - if err != nil { - t.Fatalf("Unexpected failure on QueryRow Scan: %v", err) - } - if actual != out { - t.Errorf("Did not transcode enum successfully: %s is not %s", out, actual) - } -} - func TestTimestampTzTranscode(t *testing.T) { t.Parallel() From 03c00d5e41cfb7d586e979982cdf594757e78c1c Mon Sep 17 00:00:00 2001 From: Lukas Vogel Date: Mon, 18 Mar 2019 14:20:19 +0100 Subject: [PATCH 11/70] Remove unreachable code The returns can never be reached because the loop is guaranteed to return. --- copy_from.go | 2 -- copy_to.go | 2 -- 2 files changed, 4 deletions(-) diff --git a/copy_from.go b/copy_from.go index 27e2fc9a..d4b594d9 100644 --- a/copy_from.go +++ b/copy_from.go @@ -334,6 +334,4 @@ func (c *Conn) CopyFromReader(r io.Reader, sql string) error { return c.processContextFreeMsg(msg) } } - - return nil } diff --git a/copy_to.go b/copy_to.go index 0e11a6ed..f6b8d361 100644 --- a/copy_to.go +++ b/copy_to.go @@ -59,6 +59,4 @@ func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) error return c.processContextFreeMsg(msg) } } - - return nil } From bbe778863fe8895333bb11ecb849cd2646771723 Mon Sep 17 00:00:00 2001 From: Jan Vcelak Date: Mon, 18 Mar 2019 18:52:52 +0100 Subject: [PATCH 12/70] Remove Conn methods on ReplicationConn Fixes infinite loop when any of the removed methods is called. --- replication.go | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/replication.go b/replication.go index a017dcc4..d3fdd046 100644 --- a/replication.go +++ b/replication.go @@ -201,18 +201,6 @@ func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) { return } -func (rc *ReplicationConn) Close() error { - return rc.Close() -} - -func (rc *ReplicationConn) IsAlive() bool { - return rc.IsAlive() -} - -func (rc *ReplicationConn) CauseOfDeath() error { - return rc.CauseOfDeath() -} - func (rc *ReplicationConn) GetConnInfo() *pgtype.ConnInfo { return rc.ConnInfo } From 038060776bdd85d61c2138e4ace990c2bfc742be Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Mar 2019 11:22:47 -0500 Subject: [PATCH 13/70] Use LogLevel type instead of int for conn config Technically, this is a change in the public interface. But it seems extremely unlikely that it would cause any issues (and any that do appear would be trivial to fix). fixes #516 --- conn.go | 8 ++++---- conn_pool.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/conn.go b/conn.go index 1693d159..d500b132 100644 --- a/conn.go +++ b/conn.go @@ -72,7 +72,7 @@ type ConnConfig struct { UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa FallbackTLSConfig *tls.Config // config for fallback TLS connection (only used if UseFallBackTLS is true)-- nil disables TLS Logger Logger - LogLevel int + LogLevel LogLevel Dial DialFunc RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) OnNotice NoticeHandler // Callback function called when a notice response is received. @@ -123,7 +123,7 @@ type Conn struct { channels map[string]struct{} notifications []*Notification logger Logger - logLevel int + logLevel LogLevel fp *fastpath poolResetCount int preallocatedRows []Rows @@ -1609,7 +1609,7 @@ func (c *Conn) unlock() error { return nil } -func (c *Conn) shouldLog(lvl int) bool { +func (c *Conn) shouldLog(lvl LogLevel) bool { return c.logger != nil && c.logLevel >= lvl } @@ -1633,7 +1633,7 @@ func (c *Conn) SetLogger(logger Logger) Logger { // SetLogLevel replaces the current log level and returns the previous log // level. -func (c *Conn) SetLogLevel(lvl int) (int, error) { +func (c *Conn) SetLogLevel(lvl LogLevel) (LogLevel, error) { oldLvl := c.logLevel if lvl < LogLevelNone || lvl > LogLevelTrace { diff --git a/conn_pool.go b/conn_pool.go index 77450dba..af4c879e 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -28,7 +28,7 @@ type ConnPool struct { resetCount int afterConnect func(*Conn) error logger Logger - logLevel int + logLevel LogLevel closed bool preparedStatements map[string]*PreparedStatement acquireTimeout time.Duration From 2e26d8df0374db517f7d21d4886601a82161567a Mon Sep 17 00:00:00 2001 From: Jan Vcelak Date: Mon, 25 Mar 2019 13:47:48 +0100 Subject: [PATCH 14/70] Document simple protocol on ReplicationConn --- replication.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/replication.go b/replication.go index d3fdd046..52e6b915 100644 --- a/replication.go +++ b/replication.go @@ -171,6 +171,17 @@ func ReplicationConnect(config ConnConfig) (r *ReplicationConn, err error) { return &ReplicationConn{c}, nil } +// ReplicationConn is a PostgreSQL connection handle established in the +// replication mode which enables a special set of commands for streaming WAL +// changes from the server. +// +// When in replication mode, only the simple query protocol can be used +// (see PreferSimpleProtocol in ConnConfig). Execution of normal SQL queries on +// the connection is possible but may be limited in available functionality. +// Most notably, prepared statements won't work. +// +// See https://www.postgresql.org/docs/11/protocol-replication.html for +// details. type ReplicationConn struct { *Conn } From 0b62f832b064fae8018a522ea8d5665033749a02 Mon Sep 17 00:00:00 2001 From: fzerorubigd Date: Thu, 28 Mar 2019 16:31:55 +0100 Subject: [PATCH 15/70] [stdlib] Add support for creating a DB from pgx.Pool Also the configuration used in the Conn structure (used to implement the driver.Conn interface) stores a ConnConfig which is used only for determining if the Connection should be used with Simple Protocol or not. --- stdlib/opendbpool.go | 60 ++++++++++++++++++++++++++++++++++++ stdlib/stdlibutil110_test.go | 9 +++++- 2 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 stdlib/opendbpool.go diff --git a/stdlib/opendbpool.go b/stdlib/opendbpool.go new file mode 100644 index 00000000..dd7596ea --- /dev/null +++ b/stdlib/opendbpool.go @@ -0,0 +1,60 @@ +// +build go1.10 + +package stdlib + +import ( + "context" + "database/sql" + "database/sql/driver" + + "github.com/jackc/pgx" +) + +// OptionOpenDB options for configuring the driver when opening a new db pool. +type OptionOpenDBFromPool func(*poolConnector) + +// OptionAfterConnect provide a callback for after connect. +func OptionPreferSimpleProtocol(preferSimpleProtocol bool) OptionOpenDBFromPool { + return func(dc *poolConnector) { + dc.preferSimpleProtocol = preferSimpleProtocol + } +} + +// OpenDBFromPool create a sql.DB connection from a pgx.ConnPool +func OpenDBFromPool(pool *pgx.ConnPool, opts ...OptionOpenDBFromPool) *sql.DB { + c := poolConnector{ + pool: pool, + driver: pgxDriver, + } + + for _, opt := range opts { + opt(&c) + } + + return sql.OpenDB(c) +} + +type poolConnector struct { + pool *pgx.ConnPool + driver *Driver + preferSimpleProtocol bool +} + +// Connect implement driver.Connector interface +func (pc poolConnector) Connect(ctx context.Context) (driver.Conn, error) { + var ( + err error + conn *pgx.Conn + ) + + if conn, err = pc.pool.Acquire(); err != nil { + return nil, err + } + + return &Conn{conn: conn, driver: pc.driver, connConfig: pgx.ConnConfig{PreferSimpleProtocol: pc.preferSimpleProtocol}}, nil +} + +// Driver implement driver.Connector interface +func (pc poolConnector) Driver() driver.Driver { + return pc.driver +} diff --git a/stdlib/stdlibutil110_test.go b/stdlib/stdlibutil110_test.go index c83b645b..8a5f209c 100644 --- a/stdlib/stdlibutil110_test.go +++ b/stdlib/stdlibutil110_test.go @@ -16,5 +16,12 @@ func openDB(t *testing.T) *sql.DB { t.Fatalf("pgx.ParseConnectionString failed: %v", err) } - return stdlib.OpenDB(config) + pool, err := pgx.NewConnPool(pgx.ConnPoolConfig{ + ConnConfig: config, + }) + + if err != nil { + t.Fatalf("pgx.ParseConnectionString failed: %v", err) + } + return stdlib.OpenDBFromPool(pool) } From 3e82824ff1638c132ee5066133c3ce89b17d63d9 Mon Sep 17 00:00:00 2001 From: Jan Vcelak Date: Sun, 31 Mar 2019 20:29:37 +0200 Subject: [PATCH 16/70] Enforce simple protocol on ReplicationConn --- replication.go | 1 + replication_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/replication.go b/replication.go index 52e6b915..14895ecf 100644 --- a/replication.go +++ b/replication.go @@ -163,6 +163,7 @@ func ReplicationConnect(config ConnConfig) (r *ReplicationConn, err error) { config.RuntimeParams = make(map[string]string) } config.RuntimeParams["replication"] = "database" + config.PreferSimpleProtocol = true c, err := Connect(config) if err != nil { diff --git a/replication_test.go b/replication_test.go index d06d73cd..54ac2b4a 100644 --- a/replication_test.go +++ b/replication_test.go @@ -343,3 +343,28 @@ func TestStandbyStatusParsing(t *testing.T) { t.Errorf("Unexpected write position %d", status.WalWritePosition) } } + +func TestSimpleProtocolEnforcement(t *testing.T) { + if replicationConnConfig == nil { + t.Skip("Skipping due to undefined replicationConnConfig") + } + + replicationConn := mustReplicationConnect(t, *replicationConnConfig) + defer closeReplicationConn(t, replicationConn) + + query := "select count(*) from pg_replication_slots" + + // Check that the simple query protocol is used by default + rows, err := replicationConn.Query(query) + if err != nil { + t.Fatalf("Query failed: %v", err) + } + rows.Close() + + // Check that using the extended query protocol will fail + rows, err = replicationConn.QueryEx(context.Background(), query, &pgx.QueryExOptions{SimpleProtocol: false}) + if err == nil { + t.Fatal("Query expected to fail.") + } + rows.Close() +} From 5044e8473ad948114b6cb63f6f30f94fc7834667 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 16 Apr 2019 20:46:23 -0500 Subject: [PATCH 17/70] Add SCRAM authentication --- auth_scram.go | 255 ++++++++++++++++++++++++++++++ conn.go | 2 + conn_test.go | 50 ++++++ pgproto3/authentication.go | 30 ++++ pgproto3/sasl_initial_response.go | 64 ++++++++ pgproto3/sasl_response.go | 38 +++++ 6 files changed, 439 insertions(+) create mode 100644 auth_scram.go create mode 100644 pgproto3/sasl_initial_response.go create mode 100644 pgproto3/sasl_response.go diff --git a/auth_scram.go b/auth_scram.go new file mode 100644 index 00000000..8ac8a82b --- /dev/null +++ b/auth_scram.go @@ -0,0 +1,255 @@ +// SCRAM-SHA-256 authentication +// +// Resources: +// https://tools.ietf.org/html/rfc5802 +// https://tools.ietf.org/html/rfc8265 +// https://www.postgresql.org/docs/current/sasl-authentication.html +// +// Inspiration drawn from other implementations: +// https://github.com/lib/pq/pull/608 +// https://github.com/lib/pq/pull/788 +// https://github.com/lib/pq/pull/833 +package pgx + +import ( + "bytes" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "errors" + "fmt" + "strconv" + + "github.com/jackc/pgx/pgproto3" + "golang.org/x/crypto/pbkdf2" + "golang.org/x/text/secure/precis" +) + +const clientNonceLen = 18 + +// Perform SCRAM authentication. +func (c *Conn) scramAuth(serverAuthMechanisms []string) error { + sc, err := newScramClient(serverAuthMechanisms, c.config.Password) + if err != nil { + return err + } + + // Send client-first-message in a SASLInitialResponse + saslInitialResponse := &pgproto3.SASLInitialResponse{ + AuthMechanism: "SCRAM-SHA-256", + Data: sc.clientFirstMessage(), + } + _, err = c.conn.Write(saslInitialResponse.Encode(nil)) + if err != nil { + return err + } + + // Receive server-first-message payload in a AuthenticationSASLContinue. + authMsg, err := c.rxAuthMsg(pgproto3.AuthTypeSASLContinue) + if err != nil { + return err + } + err = sc.recvServerFirstMessage(authMsg.SASLData) + if err != nil { + return err + } + + // Send client-final-message in a SASLResponse + saslResponse := &pgproto3.SASLResponse{ + Data: []byte(sc.clientFinalMessage()), + } + _, err = c.conn.Write(saslResponse.Encode(nil)) + if err != nil { + return err + } + + // Receive server-final-message payload in a AuthenticationSASLFinal. + authMsg, err = c.rxAuthMsg(pgproto3.AuthTypeSASLFinal) + if err != nil { + return err + } + return sc.recvServerFinalMessage(authMsg.SASLData) +} + +func (c *Conn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) { + msg, err := c.rxMsg() + if err != nil { + return nil, err + } + authMsg, ok := msg.(*pgproto3.Authentication) + if !ok { + return nil, errors.New("unexpected message type") + } + if authMsg.Type != typ { + return nil, errors.New("unexpected auth type") + } + + return authMsg, nil +} + +type scramClient struct { + serverAuthMechanisms []string + password []byte + clientNonce []byte + + clientFirstMessageBare []byte + + serverFirstMessage []byte + clientAndServerNonce []byte + salt []byte + iterations int + + saltedPassword []byte + authMessage []byte +} + +func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) { + sc := &scramClient{ + serverAuthMechanisms: serverAuthMechanisms, + } + + // Ensure server supports SCRAM-SHA-256 + hasScramSHA256 := false + for _, mech := range sc.serverAuthMechanisms { + if mech == "SCRAM-SHA-256" { + hasScramSHA256 = true + break + } + } + if !hasScramSHA256 { + return nil, errors.New("server does not support SCRAM-SHA-256") + } + + // precis.OpaqueString is equivalent to SASLprep for password. + var err error + sc.password, err = precis.OpaqueString.Bytes([]byte(password)) + if err != nil { + // PostgreSQL allows passwords invalid according to SCRAM / SASLprep. + sc.password = []byte(password) + } + + buf := make([]byte, clientNonceLen) + _, err = rand.Read(buf) + if err != nil { + return nil, err + } + sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf))) + base64.RawStdEncoding.Encode(sc.clientNonce, buf) + + return sc, nil +} + +func (sc *scramClient) clientFirstMessage() []byte { + sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce)) + return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare)) +} + +func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { + sc.serverFirstMessage = serverFirstMessage + buf := serverFirstMessage + if !bytes.HasPrefix(buf, []byte("r=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include r=") + } + buf = buf[2:] + + idx := bytes.IndexByte(buf, ',') + if idx == -1 { + return errors.New("invalid SCRAM server-first-message received from server: did not include s=") + } + sc.clientAndServerNonce = buf[:idx] + buf = buf[idx+1:] + + if !bytes.HasPrefix(buf, []byte("s=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include s=") + } + buf = buf[2:] + + idx = bytes.IndexByte(buf, ',') + if idx == -1 { + return errors.New("invalid SCRAM server-first-message received from server: did not include i=") + } + saltStr := buf[:idx] + buf = buf[idx+1:] + + if !bytes.HasPrefix(buf, []byte("i=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include i=") + } + buf = buf[2:] + iterationsStr := buf + + var err error + sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr)) + if err != nil { + return fmt.Errorf("invalid SCRAM salt received from server: %v", err) + } + + sc.iterations, err = strconv.Atoi(string(iterationsStr)) + if err != nil || sc.iterations <= 0 { + return fmt.Errorf("invalid SCRAM iteration count received from server: %s", iterationsStr) + } + + if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) { + return errors.New("invalid SCRAM nonce: did not start with client nonce") + } + + if len(sc.clientAndServerNonce) <= len(sc.clientNonce) { + return errors.New("invalid SCRAM nonce: did not include server nonce") + } + + return nil +} + +func (sc *scramClient) clientFinalMessage() string { + clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce)) + + sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New) + sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(",")) + + clientProof := computeClientProof(sc.saltedPassword, sc.authMessage) + + return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof) +} + +func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error { + if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) { + return errors.New("invalid SCRAM server-final-message received from server") + } + + serverSignature := serverFinalMessage[2:] + + if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) { + return errors.New("invalid SCRAM ServerSignature received from server") + } + + return nil +} + +func computeHMAC(key, msg []byte) []byte { + mac := hmac.New(sha256.New, key) + mac.Write(msg) + return mac.Sum(nil) +} + +func computeClientProof(saltedPassword, authMessage []byte) []byte { + clientKey := computeHMAC(saltedPassword, []byte("Client Key")) + storedKey := sha256.Sum256(clientKey) + clientSignature := computeHMAC(storedKey[:], authMessage) + + clientProof := make([]byte, len(clientSignature)) + for i := 0; i < len(clientSignature); i++ { + clientProof[i] = clientKey[i] ^ clientSignature[i] + } + + buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof))) + base64.StdEncoding.Encode(buf, clientProof) + return buf +} + +func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte { + serverKey := computeHMAC(saltedPassword, []byte("Server Key")) + serverSignature := computeHMAC(serverKey[:], authMessage) + buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature))) + base64.StdEncoding.Encode(buf, serverSignature) + return buf +} diff --git a/conn.go b/conn.go index d500b132..cb24748c 100644 --- a/conn.go +++ b/conn.go @@ -1425,6 +1425,8 @@ func (c *Conn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { case pgproto3.AuthTypeMD5Password: digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+string(msg.Salt[:])) err = c.txPasswordMessage(digestedPassword) + case pgproto3.AuthTypeSASL: + err = c.scramAuth(msg.SASLAuthMechanisms) default: err = errors.New("Received unknown authentication message") } diff --git a/conn_test.go b/conn_test.go index c745d392..6ca00c6d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -212,6 +212,56 @@ func TestConnectWithMD5Password(t *testing.T) { } } +func TestConnectWithSCRAMPassword(t *testing.T) { + t.Parallel() + + connString := os.Getenv("PGX_TEST_SCRAM_PASSWORD_CONN_STRING") + if connString == "" { + t.Skip("Skipping due to missing PGX_TEST_SCRAM_PASSWORD_CONN_STRING env var") + } + + connConfig, err := pgx.ParseConnectionString(connString) + if err != nil { + t.Fatalf("Unable to parse config: %v", err) + } + + conn, err := pgx.Connect(connConfig) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + + if _, present := conn.RuntimeParams["server_version"]; !present { + t.Error("Runtime parameters not stored") + } + + if conn.PID() == 0 { + t.Error("Backend PID not stored") + } + + var currentDB string + err = conn.QueryRow("select current_database()").Scan(¤tDB) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if currentDB != defaultConnConfig.Database { + t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database) + } + + var user string + err = conn.QueryRow("select current_user").Scan(&user) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if user != defaultConnConfig.User { + t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User) + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + func TestConnectWithTLSFallback(t *testing.T) { t.Parallel() diff --git a/pgproto3/authentication.go b/pgproto3/authentication.go index 77750b86..5f698d0c 100644 --- a/pgproto3/authentication.go +++ b/pgproto3/authentication.go @@ -1,6 +1,7 @@ package pgproto3 import ( + "bytes" "encoding/binary" "github.com/jackc/pgx/pgio" @@ -11,6 +12,9 @@ const ( AuthTypeOk = 0 AuthTypeCleartextPassword = 3 AuthTypeMD5Password = 5 + AuthTypeSASL = 10 + AuthTypeSASLContinue = 11 + AuthTypeSASLFinal = 12 ) type Authentication struct { @@ -18,6 +22,12 @@ type Authentication struct { // MD5Password fields Salt [4]byte + + // SASL fields + SASLAuthMechanisms []string + + // SASLContinue and SASLFinal data + SASLData []byte } func (*Authentication) Backend() {} @@ -30,6 +40,17 @@ func (dst *Authentication) Decode(src []byte) error { case AuthTypeCleartextPassword: case AuthTypeMD5Password: copy(dst.Salt[:], src[4:8]) + case AuthTypeSASL: + authMechanisms := src[4:] + for len(authMechanisms) > 1 { + idx := bytes.IndexByte(authMechanisms, 0) + if idx > 0 { + dst.SASLAuthMechanisms = append(dst.SASLAuthMechanisms, string(authMechanisms[:idx])) + authMechanisms = authMechanisms[idx+1:] + } + } + case AuthTypeSASLContinue, AuthTypeSASLFinal: + dst.SASLData = src[4:] default: return errors.Errorf("unknown authentication type: %d", dst.Type) } @@ -46,6 +67,15 @@ func (src *Authentication) Encode(dst []byte) []byte { switch src.Type { case AuthTypeMD5Password: dst = append(dst, src.Salt[:]...) + case AuthTypeSASL: + for _, s := range src.SASLAuthMechanisms { + dst = append(dst, []byte(s)...) + dst = append(dst, 0) + } + dst = append(dst, 0) + case AuthTypeSASLContinue: + dst = pgio.AppendInt32(dst, int32(len(src.SASLData))) + dst = append(dst, src.SASLData...) } pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) diff --git a/pgproto3/sasl_initial_response.go b/pgproto3/sasl_initial_response.go new file mode 100644 index 00000000..f58a6b56 --- /dev/null +++ b/pgproto3/sasl_initial_response.go @@ -0,0 +1,64 @@ +package pgproto3 + +import ( + "bytes" + "encoding/hex" + "encoding/json" + "errors" + + "github.com/jackc/pgx/pgio" +) + +type SASLInitialResponse struct { + AuthMechanism string + Data []byte +} + +func (*SASLInitialResponse) Frontend() {} + +func (dst *SASLInitialResponse) Decode(src []byte) error { + *dst = SASLInitialResponse{} + + rp := 0 + + idx := bytes.IndexByte(src, 0) + if idx < 0 { + return errors.New("invalid SASLInitialResponse") + } + + dst.AuthMechanism = string(src[rp:idx]) + rp = idx + 1 + + rp += 4 // The rest of the message is data so we can just skip the size + dst.Data = src[rp:] + + return nil +} + +func (src *SASLInitialResponse) Encode(dst []byte) []byte { + dst = append(dst, 'p') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, []byte(src.AuthMechanism)...) + dst = append(dst, 0) + + dst = pgio.AppendInt32(dst, int32(len(src.Data))) + dst = append(dst, src.Data...) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (src *SASLInitialResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + AuthMechanism string + Data string + }{ + Type: "SASLInitialResponse", + AuthMechanism: src.AuthMechanism, + Data: hex.EncodeToString(src.Data), + }) +} diff --git a/pgproto3/sasl_response.go b/pgproto3/sasl_response.go new file mode 100644 index 00000000..ed96686b --- /dev/null +++ b/pgproto3/sasl_response.go @@ -0,0 +1,38 @@ +package pgproto3 + +import ( + "encoding/hex" + "encoding/json" + + "github.com/jackc/pgx/pgio" +) + +type SASLResponse struct { + Data []byte +} + +func (*SASLResponse) Frontend() {} + +func (dst *SASLResponse) Decode(src []byte) error { + *dst = SASLResponse{Data: src} + return nil +} + +func (src *SASLResponse) Encode(dst []byte) []byte { + dst = append(dst, 'p') + dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) + + dst = append(dst, src.Data...) + + return dst +} + +func (src *SASLResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "SASLResponse", + Data: hex.EncodeToString(src.Data), + }) +} From 53dd8bf77c58e97e0bf33eea792ea211af48d2ae Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 16 Apr 2019 21:29:41 -0500 Subject: [PATCH 18/70] Travis fix --- travis/install.bash | 2 ++ 1 file changed, 2 insertions(+) diff --git a/travis/install.bash b/travis/install.bash index 63ba875d..3c3e44cf 100755 --- a/travis/install.bash +++ b/travis/install.bash @@ -12,3 +12,5 @@ go get -u github.com/sirupsen/logrus go get -u github.com/pkg/errors go get -u go.uber.org/zap go get -u github.com/rs/zerolog +go get -u golang.org/x/crypto/pbkdf2 +go get -u golang.org/x/text/secure/precis From 2492eae46cfd851a34912f64421119039d5f4299 Mon Sep 17 00:00:00 2001 From: Andrey Kuzmin Date: Mon, 22 Apr 2019 00:22:22 +0300 Subject: [PATCH 19/70] Support for pgtype.Date JSON marshal/unmarshal. JSON marshalling for types added on a as-needed basis. Partly closes https://github.com/jackc/pgx/issues/310. --- pgtype/date.go | 29 +++++++++++++++++++++++++++++ pgtype/date_test.go | 27 +++++++++++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/pgtype/date.go b/pgtype/date.go index b1d4c11d..658dbf1c 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -3,6 +3,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" "time" "github.com/jackc/pgx/pgio" @@ -207,3 +208,31 @@ func (src *Date) Value() (driver.Value, error) { return nil, errUndefined } } + +func (src *Date) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + res, err := src.Time.MarshalJSON() + if err != nil { + return nil, err + } + return res, nil + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + + return nil, errBadStatus +} + +func (dst *Date) UnmarshalJSON(b []byte) error { + var n time.Time + if err := json.Unmarshal(b, &n); err != nil { + return err + } + + *dst = Date{Time: n, Status: Present} + + return nil +} diff --git a/pgtype/date_test.go b/pgtype/date_test.go index d98e1652..178a1ff9 100644 --- a/pgtype/date_test.go +++ b/pgtype/date_test.go @@ -116,3 +116,30 @@ func TestDateAssignTo(t *testing.T) { } } } + +func TestMarshalJSON(t *testing.T) { + r := pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present} + enc, err := r.MarshalJSON() + if err != nil { + t.Errorf("%v", err) + return + } + + if string(enc) != "\"1900-01-01T00:00:00Z\"" { + t.Errorf("Incorrect json marshal") + } +} + +func TestUnmarshalJSON(t *testing.T) { + var r pgtype.Date + tm := time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC) + + if err := r.UnmarshalJSON([]byte(`"` + tm.Format(time.RFC3339) + `"`)); err != nil { + t.Errorf("%v", err) + return + } + + if tm != r.Time { + t.Errorf("Incorrect json unmarshal") + } +} From b4c77819daba7b68ffad9dfcd2d45f55c87937d1 Mon Sep 17 00:00:00 2001 From: Andrey Kuzmin Date: Tue, 23 Apr 2019 21:13:32 +0300 Subject: [PATCH 20/70] Use date as date, not datetime. Marshal/unmarshal date without time part. Date is postgresql type without time. --- pgtype/date.go | 13 +++++-------- pgtype/date_test.go | 8 +++----- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/pgtype/date.go b/pgtype/date.go index 658dbf1c..3c81a32b 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -3,7 +3,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "encoding/json" + "fmt" "time" "github.com/jackc/pgx/pgio" @@ -212,11 +212,8 @@ func (src *Date) Value() (driver.Value, error) { func (src *Date) MarshalJSON() ([]byte, error) { switch src.Status { case Present: - res, err := src.Time.MarshalJSON() - if err != nil { - return nil, err - } - return res, nil + s := fmt.Sprintf("%q", src.Time.Format("2006-01-02")) + return []byte(s), nil case Null: return []byte("null"), nil case Undefined: @@ -227,8 +224,8 @@ func (src *Date) MarshalJSON() ([]byte, error) { } func (dst *Date) UnmarshalJSON(b []byte) error { - var n time.Time - if err := json.Unmarshal(b, &n); err != nil { + n, err := time.Parse("\"2006-01-02\"", string(b)) + if err != nil { return err } diff --git a/pgtype/date_test.go b/pgtype/date_test.go index 178a1ff9..8128d87d 100644 --- a/pgtype/date_test.go +++ b/pgtype/date_test.go @@ -125,21 +125,19 @@ func TestMarshalJSON(t *testing.T) { return } - if string(enc) != "\"1900-01-01T00:00:00Z\"" { + if string(enc) != "\"1900-01-01\"" { t.Errorf("Incorrect json marshal") } } func TestUnmarshalJSON(t *testing.T) { var r pgtype.Date - tm := time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC) - - if err := r.UnmarshalJSON([]byte(`"` + tm.Format(time.RFC3339) + `"`)); err != nil { + if err := r.UnmarshalJSON([]byte("\"1900-01-01\"")); err != nil { t.Errorf("%v", err) return } - if tm != r.Time { + if r.Time.Year() != 1900 || r.Time.Month() != 1 || r.Time.Day() != 1 { t.Errorf("Incorrect json unmarshal") } } From fbb8ccee58d435e80d5391a9334c8d9e45ff11c5 Mon Sep 17 00:00:00 2001 From: Andrey Date: Sun, 28 Apr 2019 23:11:59 +0500 Subject: [PATCH 21/70] Fix few issues with copy command --- pgproto3/copy_fail.go | 48 +++++++++++++++++++++++++++++++++++ pgproto3/copy_out_response.go | 2 ++ pgproto3/frontend.go | 3 +++ 3 files changed, 53 insertions(+) create mode 100644 pgproto3/copy_fail.go diff --git a/pgproto3/copy_fail.go b/pgproto3/copy_fail.go new file mode 100644 index 00000000..c49821ce --- /dev/null +++ b/pgproto3/copy_fail.go @@ -0,0 +1,48 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" + + "github.com/jackc/pgx/pgio" +) + +type CopyFail struct { + Message string +} + +func (*CopyFail) Backend() {} + +func (dst *CopyFail) Decode(src []byte) error { + idx := bytes.IndexByte(src, 0) + if idx != len(src)-1 { + return &invalidMessageFormatErr{messageType: "CopyFail"} + } + + dst.Message = string(src[:idx]) + + return nil +} + +func (src *CopyFail) Encode(dst []byte) []byte { + dst = append(dst, 'f') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.Message...) + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (src *CopyFail) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Message string + }{ + Type: "CopyFail", + Message: src.Message, + }) +} diff --git a/pgproto3/copy_out_response.go b/pgproto3/copy_out_response.go index eaa33b8b..561eaeed 100644 --- a/pgproto3/copy_out_response.go +++ b/pgproto3/copy_out_response.go @@ -44,6 +44,8 @@ func (src *CopyOutResponse) Encode(dst []byte) []byte { sp := len(dst) dst = pgio.AppendInt32(dst, -1) + dst = append(dst, src.OverallFormat) + dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index d1541c74..4c05fcc3 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -23,6 +23,7 @@ type Frontend struct { copyInResponse CopyInResponse copyOutResponse CopyOutResponse copyDone CopyDone + copyFail CopyFail dataRow DataRow emptyQueryResponse EmptyQueryResponse errorResponse ErrorResponse @@ -75,6 +76,8 @@ func (b *Frontend) Receive() (BackendMessage, error) { msg = &b.notificationResponse case 'c': msg = &b.copyDone + case 'f': + msg = &b.copyFail case 'C': msg = &b.commandComplete case 'd': From 48df34cc07347044b027bf410bbe078498d4b4b4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 May 2019 14:23:11 -0500 Subject: [PATCH 22/70] Fix inadvertent package doc --- auth_scram.go | 1 + 1 file changed, 1 insertion(+) diff --git a/auth_scram.go b/auth_scram.go index 8ac8a82b..2c80125c 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -9,6 +9,7 @@ // https://github.com/lib/pq/pull/608 // https://github.com/lib/pq/pull/788 // https://github.com/lib/pq/pull/833 + package pgx import ( From 8faa4453fc7051d1076053f8854077753ab912f2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 May 2019 15:52:30 -0500 Subject: [PATCH 23/70] Update changelog for 3.4.0 --- CHANGELOG.md | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 305c2a36..27c7d9df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,30 @@ +# 3.4.0 (May 3, 2019) + +## Features + +* Improved .pgpass handling (Dmitry Smal) +* Adds RowsAffected for CopyToWriter and CopyFromReader (Nikolay Vorobev) +* Support binding of []int type to array integer (David Bariod) +* Expose registered driver instance to aid integration with other libraries (PLATEL Kévin) +* Allow normal queries on replication connections (Jan Vcelak) +* Add support for creating a DB from pgx.Pool (fzerorubigd) +* SCRAM authentication +* pgtype.Date JSON marshal/unmarshal (Andrey Kuzmin) + +## Fixes + +* Fix encoding of ErrorResponse (Josh Leverette) +* Use more detailed error output of unknown field (Ilya Sivanev) +* "Temporary" Write errors no longer silently break connections. +* Fix PreferSimpleProtocol overwrite (Ilya Sinelnikov) +* Fix enum handling (Robert Lin) +* Copy protocol fixes (Andrey) + +## Changes + +* Do not attempt recovery from any Write error. +* Use LogLevel type instead of int for conn config + # 3.3.0 (December 1, 2018) ## Features From 56f4f0b9d319a910016ce044a53f52fcf986ddc6 Mon Sep 17 00:00:00 2001 From: Josh Leverette Date: Mon, 20 May 2019 11:30:25 -0700 Subject: [PATCH 24/70] Hstore can have empty keys --- pgtype/hstore.go | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/pgtype/hstore.go b/pgtype/hstore.go index 71b030f9..215adc03 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -296,13 +296,9 @@ func parseHstore(s string) (k []string, v []Text, err error) { case hsKey: switch r { case '"': //End of the key - if buf.Len() == 0 { - err = errors.New("Empty Key is invalid") - } else { - keys = append(keys, buf.String()) - buf = bytes.Buffer{} - state = hsSep - } + keys = append(keys, buf.String()) + buf = bytes.Buffer{} + state = hsSep case '\\': //Potential escaped character n, end := p.Consume() switch { From 0ab6f80f9929384a8cf6cfc299b43233534eb705 Mon Sep 17 00:00:00 2001 From: avivklas Date: Mon, 3 Jun 2019 13:44:43 +0300 Subject: [PATCH 25/70] added PortalSuspended message --- pgproto3/frontend.go | 9 ++++++--- pgproto3/portal_suspended.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) create mode 100644 pgproto3/portal_suspended.go diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index 4c05fcc3..be2c01cd 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -23,7 +23,7 @@ type Frontend struct { copyInResponse CopyInResponse copyOutResponse CopyOutResponse copyDone CopyDone - copyFail CopyFail + copyFail CopyFail dataRow DataRow emptyQueryResponse EmptyQueryResponse errorResponse ErrorResponse @@ -36,6 +36,7 @@ type Frontend struct { parseComplete ParseComplete readyForQuery ReadyForQuery rowDescription RowDescription + portalSuspended PortalSuspended bodyLen int msgType byte @@ -76,8 +77,8 @@ func (b *Frontend) Receive() (BackendMessage, error) { msg = &b.notificationResponse case 'c': msg = &b.copyDone - case 'f': - msg = &b.copyFail + case 'f': + msg = &b.copyFail case 'C': msg = &b.commandComplete case 'd': @@ -112,6 +113,8 @@ func (b *Frontend) Receive() (BackendMessage, error) { msg = &b.copyBothResponse case 'Z': msg = &b.readyForQuery + case 's': + msg = &b.portalSuspended default: return nil, errors.Errorf("unknown message type: %c", b.msgType) } diff --git a/pgproto3/portal_suspended.go b/pgproto3/portal_suspended.go new file mode 100644 index 00000000..dc81b027 --- /dev/null +++ b/pgproto3/portal_suspended.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type PortalSuspended struct{} + +func (*PortalSuspended) Backend() {} + +func (dst *PortalSuspended) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "PortalSuspended", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *PortalSuspended) Encode(dst []byte) []byte { + return append(dst, 's', 0, 0, 0, 4) +} + +func (src *PortalSuspended) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "PortalSuspended", + }) +} From 9538d15c29005e5044da6ba3f4c8ff06daec1278 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Mon, 3 Jun 2019 23:51:48 +0300 Subject: [PATCH 26/70] Draft of connection writable checking Signed-off-by: Artemiy Ryabinkov --- conn.go | 74 +++++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 64 insertions(+), 10 deletions(-) diff --git a/conn.go b/conn.go index cb24748c..e0169d6d 100644 --- a/conn.go +++ b/conn.go @@ -89,6 +89,13 @@ type ConnConfig struct { // used by default. The same functionality can be controlled on a per query // basis by setting QueryExOptions.SimpleProtocol. PreferSimpleProtocol bool + + // TargetSessionAttr allows to specify which servers are accepted for this connection. + // "any", meaning that any kind of servers can be accepted. This is as well the default value. + // "read-write", to disallow connections to read-only servers, hot standbys for example. + // @see https://www.postgresql.org/message-id/CAD__OuhqPRGpcsfwPHz_PDqAGkoqS1UvnUnOnAB-LBWBW=wu4A@mail.gmail.com + // @see https://paquier.xyz/postgresql-2/postgres-10-libpq-read-write/ + TargetSessionAttrs string } func (cc *ConnConfig) networkAddress() (network, address string) { @@ -262,8 +269,15 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) } } + if c.config.TargetSessionAttrs != "" && + c.config.TargetSessionAttrs != "any" && + c.config.TargetSessionAttrs != "read-write" { + return nil, errors.New("invalid value for target_session_attrs, expected \"any\" or \"read-write\"") + } + c.onNotice = config.OnNotice + // TODO: Parse multi-hosts network, address := c.config.networkAddress() if c.config.Dial == nil { d := defaultDialer() @@ -273,22 +287,58 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) if c.shouldLog(LogLevelInfo) { c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address}) } - err = c.connect(config, network, address, config.TLSConfig) - if err != nil && config.UseFallbackTLS { - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err}) + + // TODO: Start loop for all hosts [host0 .. hostN] + for { + err = c.connect(config, network, address, config.TLSConfig) + if err != nil && config.UseFallbackTLS { + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err}) + } + err = c.connect(config, network, address, config.FallbackTLSConfig) } - err = c.connect(config, network, address, config.FallbackTLSConfig) + + if err != nil { + if c.shouldLog(LogLevelError) { + c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err}) + } + // TODO: Collect error + continue + } + + err = c.writeable() + if err != nil { + // TODO: Log info about not writable host + // TODO: Collect error + continue + } + + return c, nil } + + // TODO: Return collected errors + return nil, nil +} + +func (c *Conn) writeable() error { + if c.config.TargetSessionAttrs == "" || c.config.TargetSessionAttrs == "any" { + return nil + } + + var st string + err := c.QueryRowEx(context.Background(), "SHOW transaction_read_only", &QueryExOptions{SimpleProtocol: true}). + Scan(st) + if err != nil { - if c.shouldLog(LogLevelError) { - c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err}) - } - return nil, err + return errors.Wrap(err, "failed to fetch transaction_read_only state") } - return c, nil + if st == "on" { + return errors.New("writable connection disabled by server") + } + + return nil } func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) { @@ -709,6 +759,10 @@ func (old ConnConfig) Merge(other ConnConfig) ConnConfig { cc.PreferSimpleProtocol = old.PreferSimpleProtocol || other.PreferSimpleProtocol + if other.TargetSessionAttrs != "" { + cc.TargetSessionAttrs = other.TargetSessionAttrs + } + cc.RuntimeParams = make(map[string]string) for k, v := range old.RuntimeParams { cc.RuntimeParams[k] = v From d678216f468d1fe4dc28649feacd4b30a176769e Mon Sep 17 00:00:00 2001 From: David Hudson Date: Fri, 7 Jun 2019 15:08:36 +0100 Subject: [PATCH 27/70] pgtype: Fix -0 for numeric types Due to the special case of when the digits string was longer than 1 but only contained the negative sign and a 0, it was incorrectly stripping the 0 and attempting to parse "-" as a number. The solution is to check an extra position along to make sure a trailing 0 is not immediately preceeded by a negetive sign. Fixes #543 --- pgtype/numeric.go | 2 +- pgtype/numeric_array_test.go | 15 +++++++++++++++ pgtype/numeric_test.go | 3 +++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/pgtype/numeric.go b/pgtype/numeric.go index fb63df75..e14d02e4 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -321,7 +321,7 @@ func parseNumericString(str string) (n *big.Int, exp int32, err error) { if len(parts) > 1 { exp = int32(-len(parts[1])) } else { - for len(digits) > 1 && digits[len(digits)-1] == '0' { + for len(digits) > 1 && digits[len(digits)-1] == '0' && digits[len(digits)-2] != '-' { digits = digits[:len(digits)-1] exp++ } diff --git a/pgtype/numeric_array_test.go b/pgtype/numeric_array_test.go index 22ee1bc4..28aa67d9 100644 --- a/pgtype/numeric_array_test.go +++ b/pgtype/numeric_array_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "math" "math/big" "reflect" "testing" @@ -65,6 +66,13 @@ func TestNumericArraySet(t *testing.T) { Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, + { + source: []float32{float32(math.Copysign(0, -1))}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(0), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, { source: []float64{1}, result: pgtype.NumericArray{ @@ -72,6 +80,13 @@ func TestNumericArraySet(t *testing.T) { Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, + { + source: []float64{math.Copysign(0, -1)}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(0), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, { source: (([]float32)(nil)), result: pgtype.NumericArray{Status: pgtype.Null}, diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index 9d7d83d6..a5f70c9e 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "math" "math/big" "math/rand" "reflect" @@ -188,7 +189,9 @@ func TestNumericSet(t *testing.T) { result *pgtype.Numeric }{ {source: float32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: float32(math.Copysign(0, -1)), result: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Present}}, {source: float64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: float64(math.Copysign(0, -1)), result: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Present}}, {source: int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, {source: int16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, {source: int32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, From 9f031bb8f9bea60bd51ebc1cbaaa8e5db779b191 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sun, 16 Jun 2019 14:03:43 +0300 Subject: [PATCH 28/70] Return net.Addr from networkAddress Signed-off-by: Artemiy Ryabinkov --- conn.go | 82 +++++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 59 insertions(+), 23 deletions(-) diff --git a/conn.go b/conn.go index e0169d6d..323875a5 100644 --- a/conn.go +++ b/conn.go @@ -98,20 +98,24 @@ type ConnConfig struct { TargetSessionAttrs string } -func (cc *ConnConfig) networkAddress() (network, address string) { - network = "tcp" - address = fmt.Sprintf("%s:%d", cc.Host, cc.Port) - // See if host is a valid path, if yes connect with a socket +func (cc *ConnConfig) networkAddress() net.Addr { + // See if host is a valid path, if yes connect with a unix socket if _, err := os.Stat(cc.Host); err == nil { // For backward compatibility accept socket file paths -- but directories are now preferred - network = "unix" - address = cc.Host + network := "unix" + address := cc.Host + if !strings.Contains(address, "/.s.PGSQL.") { - address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10) + address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatUint(uint64(cc.Port), 10) } + + return &net.UnixAddr{Name: address, Net: network} } - return network, address + return &net.TCPAddr{ + Port: int(cc.Port), + IP: net.ParseIP(cc.Host), + } } // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. @@ -277,48 +281,80 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) c.onNotice = config.OnNotice - // TODO: Parse multi-hosts - network, address := c.config.networkAddress() if c.config.Dial == nil { d := defaultDialer() c.config.Dial = d.Dial } - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address}) - } + // TODO: Parse multi-hosts + hostAddr := c.config.networkAddress() + + addrs := []net.Addr{hostAddr} + + var errs []error + for _, addr := range addrs { + network, address := addr.Network(), addr.String() + + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{ + "network": network, + "address": address, + }) + } - // TODO: Start loop for all hosts [host0 .. hostN] - for { err = c.connect(config, network, address, config.TLSConfig) if err != nil && config.UseFallbackTLS { if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err}) + c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{ + "err": err, + "network": network, + "address": address, + }) } err = c.connect(config, network, address, config.FallbackTLSConfig) } if err != nil { if c.shouldLog(LogLevelError) { - c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err}) + c.log(LogLevelError, "connect failed", map[string]interface{}{ + "err": err, + "network": network, + "address": address, + }) } - // TODO: Collect error + + errs = append(errs, err) continue } err = c.writeable() if err != nil { - // TODO: Log info about not writable host - // TODO: Collect error + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "host is not writable", map[string]interface{}{ + "err": err, + "network": network, + "address": address, + }) + } + + errs = append(errs, err) continue } return c, nil } + // To keep backwards, if specific error type expected. + if len(errs) == 1 { + return nil, errs[0] + } - // TODO: Return collected errors - return nil, nil + var errmsg string + for _, err := range errs { + errmsg += ";" + err.Error() + } + + return nil, errors.New(errmsg) } func (c *Conn) writeable() error { @@ -331,7 +367,7 @@ func (c *Conn) writeable() error { Scan(st) if err != nil { - return errors.Wrap(err, "failed to fetch transaction_read_only state") + return errors.Wrap(err, "failed to fetch \"transaction_read_only\" state") } if st == "on" { From 25e1f674a2a02cacc2db24924545539366fac825 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sun, 16 Jun 2019 14:36:54 +0300 Subject: [PATCH 29/70] Fix doCancel with addr from networkAddress Signed-off-by: Artemiy Ryabinkov --- conn.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/conn.go b/conn.go index 323875a5..287ce7db 100644 --- a/conn.go +++ b/conn.go @@ -1741,8 +1741,8 @@ func quoteIdentifier(s string) string { } func doCancel(c *Conn) error { - network, address := c.config.networkAddress() - cancelConn, err := c.config.Dial(network, address) + addr := c.config.networkAddress() + cancelConn, err := c.config.Dial(addr.Network(), addr.String()) if err != nil { return err } From 6ec815a7489b64856307091567f59635b9a65bbe Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Tue, 18 Jun 2019 16:02:09 +0300 Subject: [PATCH 30/70] Support Multiple Hosts in ConnConfig Signed-off-by: Artemiy Ryabinkov --- conn.go | 146 ++++++++++++++++++++++++++++++------ conn_config_test.go.example | 3 + conn_config_test.go.travis | 2 + conn_test.go | 101 +++++++++++++++++++++++++ pgpass.go | 1 + 5 files changed, 232 insertions(+), 21 deletions(-) diff --git a/conn.go b/conn.go index 287ce7db..a6ea78f1 100644 --- a/conn.go +++ b/conn.go @@ -63,10 +63,29 @@ type DialFunc func(network, addr string) (net.Conn, error) // ConnConfig contains all the options used to establish a connection. type ConnConfig struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) - Port uint16 // default: 5432 - Database string - User string // default: OS user name + // Name of host to connect to. (e.g. localhost) + // If a host name begins with a slash, it specifies Unix-domain communication + // rather than TCP/IP communication; the value is the name of the directory + // in which the socket file is stored. (e.g. /private/tmp) + // The default behavior when host is not specified, or is empty, is to connect to localhost. + // + // A comma-separated list of host names is also accepted, + // in which case each host name in the list is tried in order; + // an empty item in the list selects the default behavior as explained above. + // @see https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS + Host string + + // Port number to connect to at the server host, + // or socket file name extension for Unix-domain connections. + // An empty or zero value, specifies the default port number — 5432. + // + // If multiple hosts were given in the Host parameter, then + // this parameter may specify a single port number to be used for all hosts, + // or for those that haven't port explicitly defined. + Port uint16 + Database string + User string // default: OS user name + // TODO: Allow password to be different for each host/port pair if a password file is used Password string TLSConfig *tls.Config // config for TLS connection -- nil disables TLS UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa @@ -95,10 +114,34 @@ type ConnConfig struct { // "read-write", to disallow connections to read-only servers, hot standbys for example. // @see https://www.postgresql.org/message-id/CAD__OuhqPRGpcsfwPHz_PDqAGkoqS1UvnUnOnAB-LBWBW=wu4A@mail.gmail.com // @see https://paquier.xyz/postgresql-2/postgres-10-libpq-read-write/ + // + // The query SHOW transaction_read_only will be sent upon any successful connection; + // if it returns on, the connection will be closed. + // If multiple hosts were specified in the connection string, + // any remaining servers will be tried just as if the connection attempt had failed. + // The default value of this parameter, any, regards all connections as acceptable. TargetSessionAttrs string } -func (cc *ConnConfig) networkAddress() net.Addr { +// hostAddr represents network end point defined as hostname or IP + port. +type hostAddr struct { + Host string + Port uint16 +} + +// Network returns the address's network name, "tcp". +func (a *hostAddr) Network() string { return "tcp" } + +// String implements net.Addr String method. +func (a *hostAddr) String() string { + if a == nil { + return "" + } + + return net.JoinHostPort(a.Host, strconv.Itoa(int(a.Port))) +} + +func (cc *ConnConfig) networkAddresses() ([]net.Addr, error) { // See if host is a valid path, if yes connect with a unix socket if _, err := os.Stat(cc.Host); err == nil { // For backward compatibility accept socket file paths -- but directories are now preferred @@ -109,13 +152,50 @@ func (cc *ConnConfig) networkAddress() net.Addr { address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatUint(uint64(cc.Port), 10) } - return &net.UnixAddr{Name: address, Net: network} + addrs := []net.Addr{ + &net.UnixAddr{Name: address, Net: network}, + } + + return addrs, nil } - return &net.TCPAddr{ - Port: int(cc.Port), - IP: net.ParseIP(cc.Host), + if cc.Host == "" { + addrs := []net.Addr{ + &net.TCPAddr{Port: int(cc.Port)}, + } + + return addrs, nil } + + var addrs []net.Addr + + hostports := strings.Split(cc.Host, ",") + for i, hostport := range hostports { + if hostport == "" { + return nil, fmt.Errorf("multi-host part %d is empty, at least host or port must be defined", i) + } + + // It's not possible to use net.TCPAddr here, cuz host may be hostname. + addr := hostAddr{ + Host: hostport, + Port: cc.Port, + } + + pos := strings.IndexByte(hostport, ':') + if pos != -1 { + p, err := strconv.ParseUint(hostport[pos+1:], 10, 16) + if err != nil { + return nil, fmt.Errorf("multi-host part %d (%s) has invalid port format", i, hostport) + } + + addr.Host = hostport[:pos] + addr.Port = uint16(p) + } + + addrs = append(addrs, &addr) + } + + return addrs, nil } // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. @@ -156,6 +236,10 @@ type Conn struct { ConnInfo *pgtype.ConnInfo frontend *pgproto3.Frontend + + // In case of Multiple Hosts we need to know what addr was used to connect. + // This address will be used to send a cancellation request. + addr net.Addr } // PreparedStatement is a description of a prepared statement @@ -286,10 +370,10 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) c.config.Dial = d.Dial } - // TODO: Parse multi-hosts - hostAddr := c.config.networkAddress() - - addrs := []net.Addr{hostAddr} + addrs, err := c.config.networkAddresses() + if err != nil { + return nil, err + } var errs []error for _, addr := range addrs { @@ -323,12 +407,23 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) }) } + // On any auth errors return immediately + if pgErr, ok := err.(PgError); ok { + switch pgErr.Code { + // @see: https://www.postgresql.org/docs/current/errcodes-appendix.html + case "28000", "28P01": // Invalid Authorization Specification + return nil, pgErr + } + } + errs = append(errs, err) continue } - err = c.writeable() + err = c.writable() if err != nil { + c.die(err) + if c.shouldLog(LogLevelInfo) { c.log(LogLevelInfo, "host is not writable", map[string]interface{}{ "err": err, @@ -341,6 +436,8 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) continue } + c.addr = addr + return c, nil } @@ -351,30 +448,34 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) var errmsg string for _, err := range errs { - errmsg += ";" + err.Error() + errmsg += "; " + err.Error() } return nil, errors.New(errmsg) } -func (c *Conn) writeable() error { +func (c *Conn) writable() error { if c.config.TargetSessionAttrs == "" || c.config.TargetSessionAttrs == "any" { return nil } var st string err := c.QueryRowEx(context.Background(), "SHOW transaction_read_only", &QueryExOptions{SimpleProtocol: true}). - Scan(st) + Scan(&st) if err != nil { return errors.Wrap(err, "failed to fetch \"transaction_read_only\" state") } - if st == "on" { + switch st { + case "on": return errors.New("writable connection disabled by server") + case "off": + // If transaction_read_only = off, then connection is writable. + return nil } - return nil + return errors.New("unexpected \"transaction_read_only\" status") } func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) { @@ -958,6 +1059,8 @@ func ParseDSN(s string) (ConnConfig, error) { // ParseConnectionString parses either a URI or a DSN connection string. // see ParseURI and ParseDSN for details. func ParseConnectionString(s string) (ConnConfig, error) { + // TODO: Multiple Hosts support + // @see: https://www.postgresql.org/docs/10/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS if u, err := url.Parse(s); err == nil && u.Scheme != "" { return ParseURI(s) } @@ -981,6 +1084,8 @@ func ParseConnectionString(s string) (ConnConfig, error) { // PGSSLROOTCERT // PGAPPNAME // PGCONNECT_TIMEOUT +// TODO: PGTARGETSESSIONATTRS support +// @see: https://www.postgresql.org/docs/10/libpq-envars.html // // Important TLS Security Notes: // ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This @@ -1741,8 +1846,7 @@ func quoteIdentifier(s string) string { } func doCancel(c *Conn) error { - addr := c.config.networkAddress() - cancelConn, err := c.config.Dial(addr.Network(), addr.String()) + cancelConn, err := c.config.Dial(c.addr.Network(), c.addr.String()) if err != nil { return err } diff --git a/conn_config_test.go.example b/conn_config_test.go.example index 096e1354..620b0ea1 100644 --- a/conn_config_test.go.example +++ b/conn_config_test.go.example @@ -7,6 +7,7 @@ import ( // "go/build" // "io/ioutil" // "path" + // "net" "github.com/jackc/pgx" ) @@ -14,6 +15,7 @@ import ( var defaultConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} // To skip tests for specific connection / authentication types set that connection param to nil +var multihostConnConfig *pgx.ConnConfig = nil var tcpConnConfig *pgx.ConnConfig = nil var unixSocketConnConfig *pgx.ConnConfig = nil var md5ConnConfig *pgx.ConnConfig = nil @@ -24,6 +26,7 @@ var customDialerConnConfig *pgx.ConnConfig = nil var replicationConnConfig *pgx.ConnConfig = nil var cratedbConnConfig *pgx.ConnConfig = nil +// var multihostConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "2.2.2.2:1,127.0.0.1,4.2.4.2", User: "pgx_md5", Password: "secret", Database: "pgx_test", Dial: (&net.Dialer{KeepAlive: 5 * time.Minute, Timeout: 100 * time.Millisecond}).Dial} // var tcpConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} // var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"} // var md5ConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} diff --git a/conn_config_test.go.travis b/conn_config_test.go.travis index cf29a743..738f1112 100644 --- a/conn_config_test.go.travis +++ b/conn_config_test.go.travis @@ -5,9 +5,11 @@ import ( "github.com/jackc/pgx" "os" "strconv" + "net" ) var defaultConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} +var multihostConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "2.2.2.2:1,127.0.0.1,4.2.4.2", User: "pgx_md5", Password: "secret", Database: "pgx_test", Dial: (&net.Dialer{KeepAlive: 5 * time.Minute, Timeout: 100 * time.Millisecond}).Dial} var tcpConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} var unixSocketConnConfig = &pgx.ConnConfig{Host: "/var/run/postgresql", User: "postgres", Database: "pgx_test"} var md5ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} diff --git a/conn_test.go b/conn_test.go index 6ca00c6d..14efbeca 100644 --- a/conn_test.go +++ b/conn_test.go @@ -84,6 +84,107 @@ func TestConnect(t *testing.T) { } } + +func TestConnectWithMultiHost(t *testing.T) { + t.Parallel() + + if multihostConnConfig == nil { + t.Skip("Skipping due to undefined multihostConnConfig") + } + + conn, err := pgx.Connect(*multihostConnConfig) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + + if _, present := conn.RuntimeParams["server_version"]; !present { + t.Error("Runtime parameters not stored") + } + + if conn.PID() == 0 { + t.Error("Backend PID not stored") + } + + var currentDB string + err = conn.QueryRow("select current_database()").Scan(¤tDB) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if currentDB != defaultConnConfig.Database { + t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database) + } + + var user string + err = conn.QueryRow("select current_user").Scan(&user) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if user != defaultConnConfig.User { + t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User) + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + + +func TestConnectWithMultiHostWritable(t *testing.T) { + t.Parallel() + + if multihostConnConfig == nil { + t.Skip("Skipping due to undefined multihostConnConfig") + } + + connConfig := *multihostConnConfig + connConfig.TargetSessionAttrs = "read-write" + + conn := mustConnect(t, connConfig) + defer closeConn(t, conn) + + if _, present := conn.RuntimeParams["server_version"]; !present { + t.Error("Runtime parameters not stored") + } + + if conn.PID() == 0 { + t.Error("Backend PID not stored") + } + + var currentDB string + err := conn.QueryRow("select current_database()").Scan(¤tDB) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if currentDB != defaultConnConfig.Database { + t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database) + } + + var user string + err = conn.QueryRow("select current_user").Scan(&user) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if user != defaultConnConfig.User { + t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User) + } + + var st string + err = conn.QueryRow("SHOW transaction_read_only").Scan(&st) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + + if st == "on" { + t.Error("Connection is not writable") + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + func TestConnectWithUnixSocketDirectory(t *testing.T) { t.Parallel() diff --git a/pgpass.go b/pgpass.go index 34b9bdf5..ff97e5f0 100644 --- a/pgpass.go +++ b/pgpass.go @@ -57,6 +57,7 @@ func parsepgpass(line, cfgHost, cfgPort, cfgDatabase, cfgUsername string) *strin return &parts[4] } +// TODO: Multi-host support func pgpass(cfg *ConnConfig) (found bool) { passfile := os.Getenv("PGPASSFILE") if passfile == "" { From 2837818b67f39d5dfb49a6b37f7d8a23ef263896 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Tue, 18 Jun 2019 17:09:38 +0300 Subject: [PATCH 31/70] fix typo Signed-off-by: Artemiy Ryabinkov --- conn.go | 2 +- conn_config_test.go.example | 1 + conn_config_test.go.travis | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/conn.go b/conn.go index a6ea78f1..153c7a3d 100644 --- a/conn.go +++ b/conn.go @@ -441,7 +441,7 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) return c, nil } - // To keep backwards, if specific error type expected. + // To keep backwards compatibility, if specific error type expected. if len(errs) == 1 { return nil, errs[0] } diff --git a/conn_config_test.go.example b/conn_config_test.go.example index 620b0ea1..2ca84ac3 100644 --- a/conn_config_test.go.example +++ b/conn_config_test.go.example @@ -8,6 +8,7 @@ import ( // "io/ioutil" // "path" // "net" + // "time" "github.com/jackc/pgx" ) diff --git a/conn_config_test.go.travis b/conn_config_test.go.travis index 738f1112..fbfb5252 100644 --- a/conn_config_test.go.travis +++ b/conn_config_test.go.travis @@ -6,6 +6,7 @@ import ( "os" "strconv" "net" + "time" ) var defaultConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} From 39b09f2c4af27a3ed19ab8def20ea094df8f65e5 Mon Sep 17 00:00:00 2001 From: jinhua luo Date: Tue, 25 Jun 2019 01:05:28 +0800 Subject: [PATCH 32/70] cast bytea to make []byte suitable for both string and binary string types --- internal/sanitize/sanitize.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index 53543b89..8939d797 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -87,7 +87,7 @@ func QuoteString(str string) string { } func QuoteBytes(buf []byte) string { - return `'\x` + hex.EncodeToString(buf) + "'" + return `'\x` + hex.EncodeToString(buf) + "'::bytea" } type sqlLexer struct { From e07faf207d03a337183040882337aa0d7dd2e743 Mon Sep 17 00:00:00 2001 From: jinhua luo Date: Tue, 25 Jun 2019 02:12:56 +0800 Subject: [PATCH 33/70] adjust the test for the patch --- internal/sanitize/sanitize_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go index 9597840e..f4337253 100644 --- a/internal/sanitize/sanitize_test.go +++ b/internal/sanitize/sanitize_test.go @@ -108,7 +108,7 @@ func TestQuerySanitize(t *testing.T) { { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []interface{}{[]byte{0, 1, 2, 3, 255}}, - expected: `select '\x00010203ff'`, + expected: `select '\x00010203ff'::bytea`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, From 134d3e8d7e9de0443bc9e5094667cdb4275f71e7 Mon Sep 17 00:00:00 2001 From: Nick Jones Date: Tue, 25 Jun 2019 12:37:52 +1000 Subject: [PATCH 34/70] Read OIDs for composite types on connection init. This used to be done, but pulled in tables which slowed down connections on databases with a large number of tables; see https://github.com/jackc/pgx/issues/140. This change includes composite types but excludes tables by joining against [pg_class](https://www.postgresql.org/docs/11/catalog-pg-class.html) in which `relkind` is `'c'` for the former and `'r'` for the latter. Fixes https://github.com/jackc/pgx/issues/420. --- conn.go | 4 +++- pgmock/pgmock.go | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/conn.go b/conn.go index cb24748c..0cf6c167 100644 --- a/conn.go +++ b/conn.go @@ -404,9 +404,11 @@ func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) { from pg_type t left join pg_type base_type on t.typelem=base_type.oid left join pg_namespace nsp on t.typnamespace=nsp.oid +left join pg_class cls on t.typrelid=cls.oid where ( - t.typtype in('b', 'p', 'r', 'e') + t.typtype in('b', 'p', 'r', 'e', 'c') and (base_type.oid is null or base_type.typtype in('b', 'p', 'r')) + and (cls.oid is null or cls.relkind='c') )` ) diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go index 4d15f7b8..d4ab0d13 100644 --- a/pgmock/pgmock.go +++ b/pgmock/pgmock.go @@ -211,9 +211,11 @@ func PgxInitSteps() []Step { from pg_type t left join pg_type base_type on t.typelem=base_type.oid left join pg_namespace nsp on t.typnamespace=nsp.oid +left join pg_class cls on t.typrelid=cls.oid where ( - t.typtype in('b', 'p', 'r', 'e') + t.typtype in('b', 'p', 'r', 'e', 'c') and (base_type.oid is null or base_type.typtype in('b', 'p', 'r')) + and (cls.oid is null or cls.relkind='c') )`, }), ExpectMessage(&pgproto3.Describe{ From c474426c1143f927738bfa077ecf10d888cc0d97 Mon Sep 17 00:00:00 2001 From: Euan Kemp Date: Tue, 25 Jun 2019 21:40:32 -0700 Subject: [PATCH 35/70] Log error message on rows-close error --- query.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/query.go b/query.go index 27969be9..14df853f 100644 --- a/query.go +++ b/query.go @@ -84,7 +84,7 @@ func (rows *Rows) Close() { rows.conn.log(LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount}) } } else if rows.conn.shouldLog(LogLevelError) { - rows.conn.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)}) + rows.conn.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "err": rows.err}) } if rows.batch != nil && rows.err != nil { From c5be74ca4e2f5f26b00bd6f02c98f43df6b8407b Mon Sep 17 00:00:00 2001 From: jinhua luo Date: Thu, 27 Jun 2019 13:16:35 +0800 Subject: [PATCH 36/70] send simple query if no args no need to parse and sanitize the sql string when no args. --- query.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/query.go b/query.go index 14df853f..5c6cbf7f 100644 --- a/query.go +++ b/query.go @@ -522,6 +522,10 @@ func (c *Conn) readUntilRowDescription() ([]FieldDescription, error) { } func (c *Conn) sanitizeAndSendSimpleQuery(sql string, args ...interface{}) (err error) { + if len(args) == 0 { + return c.sendSimpleQuery(sql) + } + if c.RuntimeParams["standard_conforming_strings"] != "on" { return errors.New("simple protocol queries must be run with standard_conforming_strings=on") } From a1d6202434aa40c3624688f6c2cacbc27eef5472 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Jun 2019 11:19:29 -0500 Subject: [PATCH 37/70] Release 3.5.0 --- CHANGELOG.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 27c7d9df..f22d8d29 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,21 @@ +# 3.5.0 (June 29, 2019) + +## Features + +* Protocol support for PortalSuspended message (avivklas) +* Read OIDs for composite types on connection init (Nick Jones) + +## Fixes + +* Hstore can have empty keys (Josh Leverette) +* Fix -0 value for numeric type (David Hudson) +* Log error message on rows-close error (Euan Kemp) + +## Changes + +* Explicitly cast binary string to bytea in simple protocol (jinhua luo) +* Skip parse and sanitize simple query when no arguments (jinhua luo) + # 3.4.0 (May 3, 2019) ## Features From 7c5d801f058d3a2c41650f9643f2873a10d98964 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Jun 2019 14:13:10 -0500 Subject: [PATCH 38/70] Add v4 prerelease notice --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index d6499ba4..b7051f65 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,10 @@ if err != nil { } ``` +## v4 Coming Soon + +This is the current stable v3 version. v4 is currently is in prelease status. Consider using [v4](https://github.com/jackc/pgx/tree/v4) for new development or test upgrading existing applications. + ## Features pgx supports many additional features beyond what is available through database/sql. From 8ba5485db6ebfa824458df98db6df29cbe128332 Mon Sep 17 00:00:00 2001 From: Nicholas Wilson Date: Thu, 4 Jul 2019 10:21:32 +0100 Subject: [PATCH 39/70] Use zap.Any for handling interface{} -> zap.Field conversion zap.Any falls back to zap.Reflect, but is better for this case, because it first checks for the types that zap handles specially. For example, time.Duration, or error, which zap.Reflect will just treat as untyped int64 or struct objects, but zap.Any is able to detect these types and print them properly. --- log/zapadapter/adapter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/log/zapadapter/adapter.go b/log/zapadapter/adapter.go index 82263b6e..a5a377e6 100644 --- a/log/zapadapter/adapter.go +++ b/log/zapadapter/adapter.go @@ -19,7 +19,7 @@ func (pl *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{ fields := make([]zapcore.Field, len(data)) i := 0 for k, v := range data { - fields[i] = zap.Reflect(k, v) + fields[i] = zap.Any(k, v) i++ } From bdac37aedb237933e817facfb4c287ae93852485 Mon Sep 17 00:00:00 2001 From: David Date: Fri, 5 Jul 2019 10:09:57 -0700 Subject: [PATCH 40/70] Registers composite types as a `pgtype.Record`. --- conn.go | 45 ++++++++++++++++++++++++++++++++++++++++++--- pgmock/pgmock.go | 4 +--- 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/conn.go b/conn.go index 0cf6c167..b613707e 100644 --- a/conn.go +++ b/conn.go @@ -404,11 +404,9 @@ func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) { from pg_type t left join pg_type base_type on t.typelem=base_type.oid left join pg_namespace nsp on t.typnamespace=nsp.oid -left join pg_class cls on t.typrelid=cls.oid where ( - t.typtype in('b', 'p', 'r', 'e', 'c') + t.typtype in('b', 'p', 'r', 'e') and (base_type.oid is null or base_type.typtype in('b', 'p', 'r')) - and (cls.oid is null or cls.relkind='c') )` ) @@ -428,6 +426,10 @@ where ( return nil, err } + if err = c.initConnInfoComposite(cinfo); err != nil { + return nil, err + } + return cinfo, nil } @@ -540,6 +542,43 @@ where t.typtype = 'd' return nil } +func (c *Conn) initConnInfoComposite(cinfo *pgtype.ConnInfo) error { + nameOIDs := make(map[string]pgtype.OID, 1) + + rows, err := c.Query(`select t.oid, t.typname +from pg_type t + join pg_class cls on t.typrelid=cls.oid +where t.typtype = 'c' + and cls.relkind='c'`) + if err != nil { + return err + } + + for rows.Next() { + var oid pgtype.OID + var name pgtype.Text + if err := rows.Scan(&oid, &name); err != nil { + return err + } + + nameOIDs[name.String] = oid + } + + if rows.Err() != nil { + return rows.Err() + } + + for name, oid := range nameOIDs { + cinfo.RegisterDataType(pgtype.DataType{ + Value: &pgtype.Record{}, + Name: name, + OID: oid, + }) + } + + return nil +} + // crateDBTypesQuery checks if the given err is likely to be the result of // CrateDB not implementing the pg_types table correctly. If yes, a CrateDB // specific query against pg_types is executed and its results are returned. If diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go index d4ab0d13..4d15f7b8 100644 --- a/pgmock/pgmock.go +++ b/pgmock/pgmock.go @@ -211,11 +211,9 @@ func PgxInitSteps() []Step { from pg_type t left join pg_type base_type on t.typelem=base_type.oid left join pg_namespace nsp on t.typnamespace=nsp.oid -left join pg_class cls on t.typrelid=cls.oid where ( - t.typtype in('b', 'p', 'r', 'e', 'c') + t.typtype in('b', 'p', 'r', 'e') and (base_type.oid is null or base_type.typtype in('b', 'p', 'r')) - and (cls.oid is null or cls.relkind='c') )`, }), ExpectMessage(&pgproto3.Describe{ From bcb2afe2be3d755f0ca53f3df0b262f3ca64096f Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Wed, 10 Jul 2019 22:59:17 +0300 Subject: [PATCH 41/70] TargetSessionAttrs as custom type Signed-off-by: Artemiy Ryabinkov --- conn.go | 28 ++++++++++++++++++++++------ conn_test.go | 2 +- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/conn.go b/conn.go index 153c7a3d..f0cf20be 100644 --- a/conn.go +++ b/conn.go @@ -61,6 +61,24 @@ type NoticeHandler func(*Conn, *Notice) // DialFunc is a function that can be used to connect to a PostgreSQL server type DialFunc func(network, addr string) (net.Conn, error) +// TargetSessionType represents target session attrs configuration parameter. +type TargetSessionType string + +// Block enumerates available values for TargetSessionType. +const ( + AnyTargetSession = "any" + ReadWriteTargetSession = "read-write" +) + +func (t TargetSessionType) isValid() error { + switch t { + case "", AnyTargetSession, ReadWriteTargetSession: + return nil + } + + return errors.New("invalid value for target_session_attrs, expected \"any\" or \"read-write\"") +} + // ConnConfig contains all the options used to establish a connection. type ConnConfig struct { // Name of host to connect to. (e.g. localhost) @@ -120,7 +138,7 @@ type ConnConfig struct { // If multiple hosts were specified in the connection string, // any remaining servers will be tried just as if the connection attempt had failed. // The default value of this parameter, any, regards all connections as acceptable. - TargetSessionAttrs string + TargetSessionAttrs TargetSessionType } // hostAddr represents network end point defined as hostname or IP + port. @@ -357,10 +375,8 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) } } - if c.config.TargetSessionAttrs != "" && - c.config.TargetSessionAttrs != "any" && - c.config.TargetSessionAttrs != "read-write" { - return nil, errors.New("invalid value for target_session_attrs, expected \"any\" or \"read-write\"") + if err := c.config.TargetSessionAttrs.isValid(); err != nil { + return nil, err } c.onNotice = config.OnNotice @@ -455,7 +471,7 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) } func (c *Conn) writable() error { - if c.config.TargetSessionAttrs == "" || c.config.TargetSessionAttrs == "any" { + if c.config.TargetSessionAttrs == "" || c.config.TargetSessionAttrs == AnyTargetSession { return nil } diff --git a/conn_test.go b/conn_test.go index 14efbeca..28bfe48b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -138,7 +138,7 @@ func TestConnectWithMultiHostWritable(t *testing.T) { } connConfig := *multihostConnConfig - connConfig.TargetSessionAttrs = "read-write" + connConfig.TargetSessionAttrs = pgx.ReadWriteTargetSession conn := mustConnect(t, connConfig) defer closeConn(t, conn) From 7d4215cb88d63e43baa8b8735bdafbfbd673c8bd Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Wed, 10 Jul 2019 23:16:46 +0300 Subject: [PATCH 42/70] fix error message building from errors array on connection establishing Signed-off-by: Artemiy Ryabinkov --- conn.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/conn.go b/conn.go index f0cf20be..e570249e 100644 --- a/conn.go +++ b/conn.go @@ -462,12 +462,12 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) return nil, errs[0] } - var errmsg string + errmsgs := make([]string, len(errs)) for _, err := range errs { - errmsg += "; " + err.Error() + errmsgs = append(errmsgs, err.Error()) } - return nil, errors.New(errmsg) + return nil, errors.New(strings.Join(errmsgs, ";")) } func (c *Conn) writable() error { From 75b4ba635c0224e046dc3c3d5e8d5d30c5b65d61 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 11 Jul 2019 00:16:58 +0300 Subject: [PATCH 43/70] try to improve readability of writable checking Signed-off-by: Artemiy Ryabinkov --- conn.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/conn.go b/conn.go index e570249e..975cf337 100644 --- a/conn.go +++ b/conn.go @@ -79,6 +79,10 @@ func (t TargetSessionType) isValid() error { return errors.New("invalid value for target_session_attrs, expected \"any\" or \"read-write\"") } +func (t TargetSessionType) writableRequired() bool { + return t == ReadWriteTargetSession +} + // ConnConfig contains all the options used to establish a connection. type ConnConfig struct { // Name of host to connect to. (e.g. localhost) @@ -436,7 +440,7 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) continue } - err = c.writable() + err = c.checkWritable() if err != nil { c.die(err) @@ -470,8 +474,8 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) return nil, errors.New(strings.Join(errmsgs, ";")) } -func (c *Conn) writable() error { - if c.config.TargetSessionAttrs == "" || c.config.TargetSessionAttrs == AnyTargetSession { +func (c *Conn) checkWritable() error { + if !c.config.TargetSessionAttrs.writableRequired() { return nil } @@ -485,7 +489,7 @@ func (c *Conn) writable() error { switch st { case "on": - return errors.New("writable connection disabled by server") + return errors.New("writable transactions disabled by server") case "off": // If transaction_read_only = off, then connection is writable. return nil From 18189fafd54ca2b678681a0550d632d1a5434f2c Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 11 Jul 2019 20:28:04 +0300 Subject: [PATCH 44/70] ParseConnectionString supports Multi-Hosts Signed-off-by: Artemiy Ryabinkov --- conn.go | 123 +++++++++++++++++++++++++++++++++++++++++---------- conn_test.go | 76 +++++++++++++++++++++++++++++++ 2 files changed, 176 insertions(+), 23 deletions(-) diff --git a/conn.go b/conn.go index 975cf337..fd134461 100644 --- a/conn.go +++ b/conn.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "io/ioutil" + "math" "net" "net/url" "os" @@ -947,16 +948,26 @@ func ParseURI(uri string) (ConnConfig, error) { cp.Password, _ = url.User.Password() } - parts := strings.SplitN(url.Host, ":", 2) - cp.Host = parts[0] - if len(parts) == 2 { - p, err := strconv.ParseUint(parts[1], 10, 16) - if err != nil { - return cp, err + hasMuliHosts := strings.IndexByte(url.Host, ',') != -1 + if !hasMuliHosts { + parts := strings.SplitN(url.Host, ":", 2) + cp.Host = parts[0] + if len(parts) == 2 { + p, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return cp, err + } + cp.Port = uint16(p) } - cp.Port = uint16(p) + } else { + cp.Host = url.Host } + cp.Database = strings.TrimLeft(url.Path, "/") + cp.TargetSessionAttrs = TargetSessionType(url.Query().Get("target_session_attrs")) + if err := cp.TargetSessionAttrs.isValid(); err != nil { + return cp, err + } if pgtimeout := url.Query().Get("connect_timeout"); pgtimeout != "" { timeout, err := strconv.ParseInt(pgtimeout, 10, 64) @@ -980,11 +991,12 @@ func ParseURI(uri string) (ConnConfig, error) { } ignoreKeys := map[string]struct{}{ - "connect_timeout": {}, - "sslcert": {}, - "sslkey": {}, - "sslmode": {}, - "sslrootcert": {}, + "connect_timeout": {}, + "sslcert": {}, + "sslkey": {}, + "sslmode": {}, + "sslrootcert": {}, + "target_session_attrs": {}, } cp.RuntimeParams = make(map[string]string) @@ -1029,6 +1041,7 @@ func ParseDSN(s string) (ConnConfig, error) { cp.RuntimeParams = make(map[string]string) + var hostval, portval string for _, b := range m { switch b[1] { case "user": @@ -1036,13 +1049,9 @@ func ParseDSN(s string) (ConnConfig, error) { case "password": cp.Password = b[2] case "host": - cp.Host = b[2] + hostval = b[2] case "port": - p, err := strconv.ParseUint(b[2], 10, 16) - if err != nil { - return cp, err - } - cp.Port = uint16(p) + portval = b[2] case "dbname": cp.Database = b[2] case "sslmode": @@ -1061,26 +1070,94 @@ func ParseDSN(s string) (ConnConfig, error) { d := defaultDialer() d.Timeout = time.Duration(timeout) * time.Second cp.Dial = d.Dial + case "target_session_attrs": + cp.TargetSessionAttrs = TargetSessionType(b[2]) + if err := cp.TargetSessionAttrs.isValid(); err != nil { + return cp, err + } default: cp.RuntimeParams[b[1]] = b[2] } } - err := configTLS(tlsArgs, &cp) + host, port, err := parseHostPortDSN(hostval, portval) if err != nil { return cp, err } + + cp.Host, cp.Port = host, port + + err = configTLS(tlsArgs, &cp) + if err != nil { + return cp, err + } + if cp.Password == "" { pgpass(&cp) } + return cp, nil } -// ParseConnectionString parses either a URI or a DSN connection string. -// see ParseURI and ParseDSN for details. +func parseHostPortDSN(hostval, portval string) (host string, port uint16, err error) { + if portval == "" { + return hostval, 0, nil + } + + hosts := strings.Split(hostval, ",") + ports := strings.Split(portval, ",") + + if len(ports) == 1 { + port, err := parsePort(portval) + if err != nil { + return "", 0, errors.Errorf("invalid port: %v", err) + } + + return hostval, port, nil + } + + if len(hosts) != len(ports) { + return "", 0, errors.New("the number of hosts and ports must be the same") + } + + hostports := make([]string, len(hosts)) + for i, host := range hosts { + hostports[i] = host + ":" + ports[i] + } + + return strings.Join(hostports, ","), 0, nil +} + +func parsePort(s string) (uint16, error) { + port, err := strconv.ParseUint(s, 10, 16) + if err != nil { + return 0, err + } + if port < 1 || port > math.MaxUint16 { + return 0, errors.New("outside range") + } + return uint16(port), nil +} + +// ParseConnectionString parses either a URI or a DSN connection string and builds ConnConfig. +// +// # Example DSN +// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca +// +// # Example URL +// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca +// +// ParseConnectionString supports specifying multiple hosts in similar manner to libpq. +// Host and port may include comma separated values that will be tried in order. +// This can be used as part of a high availability system. +// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. +// +// # Example URL +// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb +// +// # Example DSN +// user=jack password=secret host=host1,host2,host3 port=5432,5433,5434 dbname=mydb sslmode=verify-ca func ParseConnectionString(s string) (ConnConfig, error) { - // TODO: Multiple Hosts support - // @see: https://www.postgresql.org/docs/10/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS if u, err := url.Parse(s); err == nil && u.Scheme != "" { return ParseURI(s) } diff --git a/conn_test.go b/conn_test.go index 28bfe48b..7719bec7 100644 --- a/conn_test.go +++ b/conn_test.go @@ -622,6 +622,38 @@ func TestParseURI(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + url: "postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "foo.example.com:5432,bar.example.com:5432", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgres://jack@localhost,10.10.20.30/mydb?application_name=pgxtest&target_session_attrs=read-write", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost,10.10.20.30", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + }, + TargetSessionAttrs: pgx.ReadWriteTargetSession, + }, + }, } for i, tt := range tests { @@ -748,6 +780,50 @@ func TestParseDSN(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + url: "user=jack host=localhost1,localhost2 dbname=mydb connect_timeout=10", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost1,localhost2", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack host=100.200.220.50,localhost43 port=5432,5433 dbname=mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "100.200.220.50:5432,localhost43:5433", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack host=localhost dbname=mydb target_session_attrs=read-write", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + TargetSessionAttrs: pgx.ReadWriteTargetSession, + }, + }, } for i, tt := range tests { From 39cbdf789d3448c56ba394557e8100a143694c56 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 11 Jul 2019 20:56:44 +0300 Subject: [PATCH 45/70] Support of PGTARGETSESSIONATTRS ENV variable Signed-off-by: Artemiy Ryabinkov --- conn.go | 8 +++++++- pgpass.go | 1 - 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/conn.go b/conn.go index fd134461..76e09e73 100644 --- a/conn.go +++ b/conn.go @@ -1016,6 +1016,7 @@ func ParseURI(uri string) (ConnConfig, error) { if cp.Password == "" { pgpass(&cp) } + return cp, nil } @@ -1181,7 +1182,7 @@ func ParseConnectionString(s string) (ConnConfig, error) { // PGSSLROOTCERT // PGAPPNAME // PGCONNECT_TIMEOUT -// TODO: PGTARGETSESSIONATTRS support +// PGTARGETSESSIONATTRS // @see: https://www.postgresql.org/docs/10/libpq-envars.html // // Important TLS Security Notes: @@ -1228,6 +1229,11 @@ func ParseEnvLibpq() (ConnConfig, error) { } } + cc.TargetSessionAttrs = TargetSessionType(os.Getenv("PGTARGETSESSIONATTRS")) + if err := cc.TargetSessionAttrs.isValid(); err != nil { + return cc, err + } + tlsArgs := configTLSArgs{ sslMode: os.Getenv("PGSSLMODE"), sslKey: os.Getenv("PGSSLKEY"), diff --git a/pgpass.go b/pgpass.go index ff97e5f0..34b9bdf5 100644 --- a/pgpass.go +++ b/pgpass.go @@ -57,7 +57,6 @@ func parsepgpass(line, cfgHost, cfgPort, cfgDatabase, cfgUsername string) *strin return &parts[4] } -// TODO: Multi-host support func pgpass(cfg *ConnConfig) (found bool) { passfile := os.Getenv("PGPASSFILE") if passfile == "" { From f87825cac7b1ae3311a31a2093bcb00065667ba6 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 11 Jul 2019 21:38:29 +0300 Subject: [PATCH 46/70] remove TODO that PR will not cover Signed-off-by: Artemiy Ryabinkov --- conn.go | 1 - 1 file changed, 1 deletion(-) diff --git a/conn.go b/conn.go index 76e09e73..e195bb2f 100644 --- a/conn.go +++ b/conn.go @@ -108,7 +108,6 @@ type ConnConfig struct { Port uint16 Database string User string // default: OS user name - // TODO: Allow password to be different for each host/port pair if a password file is used Password string TLSConfig *tls.Config // config for TLS connection -- nil disables TLS UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa From 98acf573cce94af544d41b3e2bbfc9d86b7494cf Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sat, 13 Jul 2019 21:21:23 +0300 Subject: [PATCH 47/70] fix errors collecting on multi-host Signed-off-by: Artemiy Ryabinkov --- conn.go | 4 +- examples/multihosts/README.md | 25 ++++++++++++ examples/multihosts/main.go | 74 +++++++++++++++++++++++++++++++++++ 3 files changed, 101 insertions(+), 2 deletions(-) create mode 100644 examples/multihosts/README.md create mode 100644 examples/multihosts/main.go diff --git a/conn.go b/conn.go index e195bb2f..7f2ae6a8 100644 --- a/conn.go +++ b/conn.go @@ -467,8 +467,8 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) } errmsgs := make([]string, len(errs)) - for _, err := range errs { - errmsgs = append(errmsgs, err.Error()) + for i, err := range errs { + errmsgs[i] = err.Error() } return nil, errors.New(strings.Join(errmsgs, ";")) diff --git a/examples/multihosts/README.md b/examples/multihosts/README.md new file mode 100644 index 00000000..4b73eb51 --- /dev/null +++ b/examples/multihosts/README.md @@ -0,0 +1,25 @@ +# Description + +This is a sample chat program implemented using PostgreSQL's listen/notify +functionality with pgx. + +Start multiple instances of this program connected to the same database to chat +between them. + +## Connection configuration + +The database connection is configured via the standard PostgreSQL environment variables. + +* PGHOST - defaults to localhost +* PGUSER - defaults to current OS user +* PGPASSWORD - defaults to empty string +* PGDATABASE - defaults to user name + +You can either export them then run chat: + + export PGHOST=/private/tmp + ./chat + +Or you can prefix the chat execution with the environment variables: + + PGHOST=/private/tmp ./chat diff --git a/examples/multihosts/main.go b/examples/multihosts/main.go new file mode 100644 index 00000000..83b16c02 --- /dev/null +++ b/examples/multihosts/main.go @@ -0,0 +1,74 @@ +package main + +import ( + "bufio" + "context" + "fmt" + "os" + + "github.com/jackc/pgx" +) + +var pool *pgx.ConnPool + +func main() { + config, err := pgx.ParseEnvLibpq() + if err != nil { + fmt.Fprintln(os.Stderr, "Unable to parse environment:", err) + os.Exit(1) + } + + pool, err = pgx.NewConnPool(pgx.ConnPoolConfig{ConnConfig: config}) + if err != nil { + fmt.Fprintln(os.Stderr, "Unable to connect to database:", err) + os.Exit(1) + } + + go listen() + + fmt.Println(`Type a message and press enter. + +This message should appear in any other chat instances connected to the same +database. + +Type "exit" to quit.`) + + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + msg := scanner.Text() + if msg == "exit" { + os.Exit(0) + } + + _, err = pool.Exec("select pg_notify('chat', $1)", msg) + if err != nil { + fmt.Fprintln(os.Stderr, "Error sending notification:", err) + os.Exit(1) + } + } + if err := scanner.Err(); err != nil { + fmt.Fprintln(os.Stderr, "Error scanning from stdin:", err) + os.Exit(1) + } +} + +func listen() { + conn, err := pool.Acquire() + if err != nil { + fmt.Fprintln(os.Stderr, "Error acquiring connection:", err) + os.Exit(1) + } + defer pool.Release(conn) + + conn.Listen("chat") + + for { + notification, err := conn.WaitForNotification(context.Background()) + if err != nil { + fmt.Fprintln(os.Stderr, "Error waiting for notification:", err) + os.Exit(1) + } + + fmt.Println("PID:", notification.PID, "Channel:", notification.Channel, "Payload:", notification.Payload) + } +} From a2b647c393b3349c9ddf568d279e7fcd71520f88 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sat, 13 Jul 2019 22:17:03 +0300 Subject: [PATCH 48/70] drop extra example Signed-off-by: Artemiy Ryabinkov --- conn.go | 2 +- examples/multihosts/README.md | 25 ------------ examples/multihosts/main.go | 74 ----------------------------------- 3 files changed, 1 insertion(+), 100 deletions(-) delete mode 100644 examples/multihosts/README.md delete mode 100644 examples/multihosts/main.go diff --git a/conn.go b/conn.go index 7f2ae6a8..01a572d7 100644 --- a/conn.go +++ b/conn.go @@ -471,7 +471,7 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) errmsgs[i] = err.Error() } - return nil, errors.New(strings.Join(errmsgs, ";")) + return nil, errors.New(strings.Join(errmsgs, "; ")) } func (c *Conn) checkWritable() error { diff --git a/examples/multihosts/README.md b/examples/multihosts/README.md deleted file mode 100644 index 4b73eb51..00000000 --- a/examples/multihosts/README.md +++ /dev/null @@ -1,25 +0,0 @@ -# Description - -This is a sample chat program implemented using PostgreSQL's listen/notify -functionality with pgx. - -Start multiple instances of this program connected to the same database to chat -between them. - -## Connection configuration - -The database connection is configured via the standard PostgreSQL environment variables. - -* PGHOST - defaults to localhost -* PGUSER - defaults to current OS user -* PGPASSWORD - defaults to empty string -* PGDATABASE - defaults to user name - -You can either export them then run chat: - - export PGHOST=/private/tmp - ./chat - -Or you can prefix the chat execution with the environment variables: - - PGHOST=/private/tmp ./chat diff --git a/examples/multihosts/main.go b/examples/multihosts/main.go deleted file mode 100644 index 83b16c02..00000000 --- a/examples/multihosts/main.go +++ /dev/null @@ -1,74 +0,0 @@ -package main - -import ( - "bufio" - "context" - "fmt" - "os" - - "github.com/jackc/pgx" -) - -var pool *pgx.ConnPool - -func main() { - config, err := pgx.ParseEnvLibpq() - if err != nil { - fmt.Fprintln(os.Stderr, "Unable to parse environment:", err) - os.Exit(1) - } - - pool, err = pgx.NewConnPool(pgx.ConnPoolConfig{ConnConfig: config}) - if err != nil { - fmt.Fprintln(os.Stderr, "Unable to connect to database:", err) - os.Exit(1) - } - - go listen() - - fmt.Println(`Type a message and press enter. - -This message should appear in any other chat instances connected to the same -database. - -Type "exit" to quit.`) - - scanner := bufio.NewScanner(os.Stdin) - for scanner.Scan() { - msg := scanner.Text() - if msg == "exit" { - os.Exit(0) - } - - _, err = pool.Exec("select pg_notify('chat', $1)", msg) - if err != nil { - fmt.Fprintln(os.Stderr, "Error sending notification:", err) - os.Exit(1) - } - } - if err := scanner.Err(); err != nil { - fmt.Fprintln(os.Stderr, "Error scanning from stdin:", err) - os.Exit(1) - } -} - -func listen() { - conn, err := pool.Acquire() - if err != nil { - fmt.Fprintln(os.Stderr, "Error acquiring connection:", err) - os.Exit(1) - } - defer pool.Release(conn) - - conn.Listen("chat") - - for { - notification, err := conn.WaitForNotification(context.Background()) - if err != nil { - fmt.Fprintln(os.Stderr, "Error waiting for notification:", err) - os.Exit(1) - } - - fmt.Println("PID:", notification.PID, "Channel:", notification.Channel, "Payload:", notification.Payload) - } -} From 1ecc111e17995b5aba2e0b7b1fd57c616f9172a7 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sun, 14 Jul 2019 18:29:08 +0300 Subject: [PATCH 49/70] Reuse pool.connInfo for createConnectionUnlocked method Signed-off-by: Artemiy Ryabinkov --- conn_pool.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conn_pool.go b/conn_pool.go index 47a0b391..d43b6337 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -319,7 +319,7 @@ func (p *ConnPool) createConnection() (*Conn, error) { func (p *ConnPool) createConnectionUnlocked() (*Conn, error) { p.inProgressConnects++ p.cond.L.Unlock() - c, err := Connect(p.config) + c, err := connect(p.config, p.connInfo) p.cond.L.Lock() p.inProgressConnects-- From 8e0e1123dfa4f7280ad56da42fa211bb91ea39f4 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sun, 14 Jul 2019 20:04:55 +0300 Subject: [PATCH 50/70] use deepCopy of connInfo in createConnectionUnlocked method Signed-off-by: Artemiy Ryabinkov --- conn_pool.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conn_pool.go b/conn_pool.go index d43b6337..e8972a0b 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -319,7 +319,7 @@ func (p *ConnPool) createConnection() (*Conn, error) { func (p *ConnPool) createConnectionUnlocked() (*Conn, error) { p.inProgressConnects++ p.cond.L.Unlock() - c, err := connect(p.config, p.connInfo) + c, err := connect(p.config, p.connInfo.DeepCopy()) p.cond.L.Lock() p.inProgressConnects-- From fc020c24ac9590f6547f8ad1d291fc75b4873a84 Mon Sep 17 00:00:00 2001 From: Nicholas Wilson Date: Wed, 24 Jul 2019 12:32:18 +0100 Subject: [PATCH 51/70] Add support for pgtype.UUID to write into any [16]byte type --- pgtype/convert.go | 29 +++++++++++++++++++++++++++++ pgtype/uuid.go | 2 +- pgtype/uuid_test.go | 21 +++++++++++++++++++++ 3 files changed, 51 insertions(+), 1 deletion(-) diff --git a/pgtype/convert.go b/pgtype/convert.go index 5dfb738e..ee6907c4 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -163,6 +163,27 @@ func underlyingTimeType(val interface{}) (interface{}, bool) { return time.Time{}, false } +// underlyingUUIDType gets the underlying type that can be converted to [16]byte +func underlyingUUIDType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return time.Time{}, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + uuidType := reflect.TypeOf([16]byte{}) + if refVal.Type().ConvertibleTo(uuidType) { + return refVal.Convert(uuidType).Interface(), true + } + + return nil, false +} + // underlyingSliceType gets the underlying slice type func underlyingSliceType(val interface{}) (interface{}, bool) { refVal := reflect.ValueOf(val) @@ -401,6 +422,14 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { } } + if dstVal.Kind() == reflect.Array { + if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { + baseArrayType := reflect.PtrTo(reflect.ArrayOf(dstVal.Len(), baseElemType)) + nextDst := dstPtr.Convert(baseArrayType) + return nextDst.Interface(), dstPtr.Type() != nextDst.Type() + } + } + return nil, false } diff --git a/pgtype/uuid.go b/pgtype/uuid.go index 5e1eead5..8d33d8f8 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -39,7 +39,7 @@ func (dst *UUID) Set(src interface{}) error { } *dst = UUID{Bytes: uuid, Status: Present} default: - if originalSrc, ok := underlyingPtrType(src); ok { + if originalSrc, ok := underlyingUUIDType(src); ok { return dst.Set(originalSrc) } return errors.Errorf("cannot convert %v to UUID", value) diff --git a/pgtype/uuid_test.go b/pgtype/uuid_test.go index 162d999f..1eddeda1 100644 --- a/pgtype/uuid_test.go +++ b/pgtype/uuid_test.go @@ -15,6 +15,8 @@ func TestUUIDTranscode(t *testing.T) { }) } +type SomeUUIDType [16]byte + func TestUUIDSet(t *testing.T) { successfulTests := []struct { source interface{} @@ -32,6 +34,10 @@ func TestUUIDSet(t *testing.T) { source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, + { + source: SomeUUIDType{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, { source: ([]byte)(nil), result: pgtype.UUID{Status: pgtype.Null}, @@ -86,6 +92,21 @@ func TestUUIDAssignTo(t *testing.T) { } } + { + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst SomeUUIDType + expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + { src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} var dst string From 251e6b7730c7b31b600e6fe06162e541f3032604 Mon Sep 17 00:00:00 2001 From: Nicholas Wilson Date: Wed, 24 Jul 2019 12:32:43 +0100 Subject: [PATCH 52/70] Tidying: make underlyingTimeType consistent with other underlyingFooType The first return value is ignored when returning false - so there's no point returning an empty time.Time when it can be nil. --- pgtype/convert.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pgtype/convert.go b/pgtype/convert.go index ee6907c4..029e3d48 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -149,7 +149,7 @@ func underlyingTimeType(val interface{}) (interface{}, bool) { switch refVal.Kind() { case reflect.Ptr: if refVal.IsNil() { - return time.Time{}, false + return nil, false } convVal := refVal.Elem().Interface() return convVal, true @@ -160,7 +160,7 @@ func underlyingTimeType(val interface{}) (interface{}, bool) { return refVal.Convert(timeType).Interface(), true } - return time.Time{}, false + return nil, false } // underlyingUUIDType gets the underlying type that can be converted to [16]byte From 92cd1ad639bf07d9395db46faecbbe73ac7d59ef Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Mon, 29 Jul 2019 21:19:36 +0300 Subject: [PATCH 53/70] Set 8KB as default size of ChunkReader buffer Signed-off-by: Artemiy Ryabinkov --- chunkreader/chunkreader.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/chunkreader/chunkreader.go b/chunkreader/chunkreader.go index f8d437b2..5c36292d 100644 --- a/chunkreader/chunkreader.go +++ b/chunkreader/chunkreader.go @@ -28,7 +28,11 @@ func NewChunkReader(r io.Reader) *ChunkReader { func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) { if options.MinBufLen == 0 { - options.MinBufLen = 4096 + // By historical reasons Postgres currently has 8KB send buffer inside, + // so here we want to have at least the same size buffer. + // @see https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134 + // @see https://www.postgresql.org/message-id/0cdc5485-cb3c-5e16-4a46-e3b2f7a41322%40ya.ru + options.MinBufLen = 8192 } return &ChunkReader{ From 95ea78048a9569250c078d1965a235a214239960 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Aug 2019 09:45:04 -0500 Subject: [PATCH 54/70] Remove 0 bytes when sanitizing identifiers fixes #562 --- conn.go | 9 +++++---- conn_test.go | 30 +++++++++++++++++++++++++----- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/conn.go b/conn.go index 03f9d190..121297b8 100644 --- a/conn.go +++ b/conn.go @@ -105,9 +105,9 @@ type ConnConfig struct { // If multiple hosts were given in the Host parameter, then // this parameter may specify a single port number to be used for all hosts, // or for those that haven't port explicitly defined. - Port uint16 - Database string - User string // default: OS user name + Port uint16 + Database string + User string // default: OS user name Password string TLSConfig *tls.Config // config for TLS connection -- nil disables TLS UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa @@ -307,7 +307,8 @@ type Identifier []string func (ident Identifier) Sanitize() string { parts := make([]string, len(ident)) for i := range ident { - parts[i] = `"` + strings.Replace(ident[i], `"`, `""`, -1) + `"` + s := strings.Replace(ident[i], string([]byte{0}), "", -1) + parts[i] = `"` + strings.Replace(s, `"`, `""`, -1) + `"` } return strings.Join(parts, ".") } diff --git a/conn_test.go b/conn_test.go index 7719bec7..fea3b659 100644 --- a/conn_test.go +++ b/conn_test.go @@ -84,7 +84,6 @@ func TestConnect(t *testing.T) { } } - func TestConnectWithMultiHost(t *testing.T) { t.Parallel() @@ -129,7 +128,6 @@ func TestConnectWithMultiHost(t *testing.T) { } } - func TestConnectWithMultiHostWritable(t *testing.T) { t.Parallel() @@ -818,9 +816,9 @@ func TestParseDSN(t *testing.T) { TLSConfig: &tls.Config{ InsecureSkipVerify: true, }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, TargetSessionAttrs: pgx.ReadWriteTargetSession, }, }, @@ -2319,6 +2317,24 @@ func TestSetLogLevel(t *testing.T) { } } +func TestIdentifierSanitizeNullSentToServer(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ident := pgx.Identifier{"foo" + string([]byte{0}) + "bar"} + + var n int64 + err := conn.QueryRow(`select 1 as ` + ident.Sanitize()).Scan(&n) + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Fatal("unexpected n") + } +} + func TestIdentifierSanitize(t *testing.T) { t.Parallel() @@ -2346,6 +2362,10 @@ func TestIdentifierSanitize(t *testing.T) { ident: pgx.Identifier{`you should " not do this`, `please don't`}, expected: `"you should "" not do this"."please don't"`, }, + { + ident: pgx.Identifier{`you should ` + string([]byte{0}) + `not do this`}, + expected: `"you should not do this"`, + }, } for i, tt := range tests { From 7fe7f33557938739e5342d82d0720523c344eb71 Mon Sep 17 00:00:00 2001 From: "Andrew S. Brown" Date: Sun, 4 Aug 2019 15:31:32 -0700 Subject: [PATCH 55/70] Terminate context prior to releasing when killing batch connection --- batch.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/batch.go b/batch.go index 4b624387..8c924e8d 100644 --- a/batch.go +++ b/batch.go @@ -135,7 +135,7 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { _, err = b.conn.conn.Write(buf) if err != nil { - b.conn.die(err) + b.die(err) return err } @@ -281,10 +281,13 @@ func (b *Batch) die(err error) { } b.err = err - b.conn.die(err) + if b.conn != nil { + err = b.conn.termContext(err) + b.conn.die(err) - if b.conn != nil && b.connPool != nil { - b.connPool.Release(b.conn) + if b.connPool != nil { + b.connPool.Release(b.conn) + } } } From ca9de512569587bfd5f26ffd2a5e266a7bbfbef5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 6 Aug 2019 16:42:20 -0500 Subject: [PATCH 56/70] Return deferred errors Deferred errors are sent after the CommandComplete message. They could be silently dropped depending on the context in which it occurred. fixes #570 --- batch.go | 17 +++++++++++++++++ batch_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++++++++++ conn_test.go | 26 ++++++++++++++++++++++++++ query.go | 19 +++++++++++++++++++ query_test.go | 43 +++++++++++++++++++++++++++++++++++++++++- 5 files changed, 156 insertions(+), 1 deletion(-) diff --git a/batch.go b/batch.go index 8c924e8d..7f5422dc 100644 --- a/batch.go +++ b/batch.go @@ -268,6 +268,23 @@ func (b *Batch) Close() (err error) { } } + for b.conn.pendingReadyForQueryCount > 0 { + msg, err := b.conn.rxMsg() + if err != nil { + return err + } + + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + return b.conn.rxErrorResponse(msg) + default: + err = b.conn.processContextFreeMsg(msg) + if err != nil { + return err + } + } + } + if err = b.conn.ensureConnectionReadyForQuery(); err != nil { return err } diff --git a/batch_test.go b/batch_test.go index 61bbe357..d0e26875 100644 --- a/batch_test.go +++ b/batch_test.go @@ -701,3 +701,55 @@ func TestTxBeginBatchRollback(t *testing.T) { ensureConnValid(t, conn) } + +func TestConnBeginBatchDeferredError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) + + batch := conn.BeginBatch() + batch.Queue(`update t set n=n+1 where id='b' returning *`, + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + rows, err := batch.QueryResults() + if err != nil { + t.Error(err) + } + + for rows.Next() { + var id string + var n int32 + err = rows.Scan(&id, &n) + if err != nil { + t.Fatal(err) + } + } + + err = batch.Close() + if err == nil { + t.Fatal("expected error 23505 but got none") + } + + if err, ok := err.(pgx.PgError); !ok || err.Code != "23505" { + t.Fatalf("expected error 23505, got %v", err) + } + + ensureConnValid(t, conn) +} diff --git a/conn_test.go b/conn_test.go index fea3b659..c6ce50cc 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1370,6 +1370,32 @@ func TestExecFailure(t *testing.T) { } } +func TestExecDeferredError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred +); + +insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) + + _, err := conn.Exec(`update t set n=n+1 where id='b'`) + if err == nil { + t.Fatal("expected error 23505 but got none") + } + + if err, ok := err.(pgx.PgError); !ok || err.Code != "23505" { + t.Fatalf("expected error 23505, got %v", err) + } + + ensureConnValid(t, conn) +} + func TestExecFailureWithArguments(t *testing.T) { t.Parallel() diff --git a/query.go b/query.go index 5c6cbf7f..bf4ec561 100644 --- a/query.go +++ b/query.go @@ -69,6 +69,25 @@ func (rows *Rows) Close() { return } + // If there is no error and a batch operation is not in progress read until we get the ReadyForQuery message or the + // ErrorResponse. This is necessary to detect a deferred constraint violation where the ErrorResponse is sent after + // CommandComplete. + if rows.err == nil && rows.batch == nil && rows.conn.pendingReadyForQueryCount == 1 { + for rows.conn.pendingReadyForQueryCount > 0 { + msg, err := rows.conn.rxMsg() + if err != nil { + rows.err = err + break + } + + err = rows.conn.processContextFreeMsg(msg) + if err != nil { + rows.err = err + break + } + } + } + if rows.unlockConn { rows.conn.unlock() rows.unlockConn = false diff --git a/query_test.go b/query_test.go index 06b7b8b7..ea1fd66e 100644 --- a/query_test.go +++ b/query_test.go @@ -14,7 +14,7 @@ import ( "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" satori "github.com/jackc/pgx/pgtype/ext/satori-uuid" - "github.com/satori/go.uuid" + uuid "github.com/satori/go.uuid" "github.com/shopspring/decimal" ) @@ -424,6 +424,47 @@ func TestConnQueryErrorWhileReturningRows(t *testing.T) { } +// https://github.com/jackc/pgx/issues/570 +func TestConnQueryDeferredError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred +); + +insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) + + rows, err := conn.Query(`update t set n=n+1 where id='b' returning *`) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + for rows.Next() { + var id string + var n int32 + err = rows.Scan(&id, &n) + if err != nil { + t.Fatal(err) + } + } + + if rows.Err() == nil { + t.Fatal("expected error 23505 but got none") + } + + if err, ok := rows.Err().(pgx.PgError); !ok || err.Code != "23505" { + t.Fatalf("expected error 23505, got %v", err) + } + + ensureConnValid(t, conn) +} + func TestQueryEncodeError(t *testing.T) { t.Parallel() From 9e3f51e5c6759ee9d6eadfa33240b9503e39b096 Mon Sep 17 00:00:00 2001 From: Nathaniel Caza Date: Wed, 7 Aug 2019 13:49:34 -0500 Subject: [PATCH 57/70] Allow specifying LevelRepeatableRead --- stdlib/sql.go | 2 +- stdlib/sql_test.go | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/stdlib/sql.go b/stdlib/sql.go index ec5933f3..e564152f 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -277,7 +277,7 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e pgxOpts.IsoLevel = pgx.ReadUncommitted case sql.LevelReadCommitted: pgxOpts.IsoLevel = pgx.ReadCommitted - case sql.LevelSnapshot: + case sql.LevelRepeatableRead, sql.LevelSnapshot: pgxOpts.IsoLevel = pgx.RepeatableRead case sql.LevelSerializable: pgxOpts.IsoLevel = pgx.Serializable diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index cf2b91b1..895ee583 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -629,6 +629,7 @@ func TestConnBeginTxIsolation(t *testing.T) { {sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel}, {sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"}, {sqlIso: sql.LevelReadCommitted, pgIso: "read committed"}, + {sqlIso: sql.LevelRepeatableRead, pgIso: "repeatable read"}, {sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"}, {sqlIso: sql.LevelSerializable, pgIso: "serializable"}, } From 50b92ce0f591145c4b3da1c1e8fb9d98db845ebb Mon Sep 17 00:00:00 2001 From: Ian Stapleton Cordasco Date: Sun, 11 Aug 2019 08:16:48 -0500 Subject: [PATCH 58/70] Correct WaitForNotification example While working on a project that was using this, I tried using the example code but instead found that WaitForNotification expects a Context (which makes sense). This corrects the docs for folks using that as a jumping off point. --- doc.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc.go b/doc.go index 5808c09d..0c2b35d3 100644 --- a/doc.go +++ b/doc.go @@ -225,7 +225,7 @@ notification. return nil } - if notification, err := conn.WaitForNotification(time.Second); err != nil { + if notification, err := conn.WaitForNotification(context.TODO()); err != nil { // do something with notification } From 809600d6671eeea159f1560abff7af084d71f1a0 Mon Sep 17 00:00:00 2001 From: Jonathan Yoder Date: Thu, 15 Aug 2019 09:31:38 -0400 Subject: [PATCH 59/70] Clarify stdlib.AcquireConn Comment --- stdlib/sql.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stdlib/sql.go b/stdlib/sql.go index e564152f..3cd2d941 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -43,8 +43,8 @@ // // AcquireConn and ReleaseConn acquire and release a *pgx.Conn from the standard // database/sql.DB connection pool. This allows operations that must be -// performed on a single connection, but should not be run in a transaction or -// to use pgx specific functionality. +// performed on a single connection without running in a transaction, and it +// supports operations that use pgx specific functionality. // // conn, err := stdlib.AcquireConn(db) // if err != nil { From 7829081b8c1eebc860dab63378b60eb47456bea2 Mon Sep 17 00:00:00 2001 From: Dmitriy Garanzha Date: Fri, 16 Aug 2019 13:22:16 +0300 Subject: [PATCH 60/70] Load user-defined array type oids. --- conn.go | 2 +- pgmock/pgmock.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/conn.go b/conn.go index 121297b8..9dc4cdbf 100644 --- a/conn.go +++ b/conn.go @@ -615,7 +615,7 @@ left join pg_namespace nsp on t.typnamespace=nsp.oid left join pg_class cls on t.typrelid=cls.oid where ( t.typtype in('b', 'p', 'r', 'e', 'c') - and (base_type.oid is null or base_type.typtype in('b', 'p', 'r')) + and (base_type.oid is null or base_type.typtype in('b', 'p', 'r', 'c')) and (cls.oid is null or cls.relkind='c') )` ) diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go index d4ab0d13..5c3fdc27 100644 --- a/pgmock/pgmock.go +++ b/pgmock/pgmock.go @@ -214,7 +214,7 @@ left join pg_namespace nsp on t.typnamespace=nsp.oid left join pg_class cls on t.typrelid=cls.oid where ( t.typtype in('b', 'p', 'r', 'e', 'c') - and (base_type.oid is null or base_type.typtype in('b', 'p', 'r')) + and (base_type.oid is null or base_type.typtype in('b', 'p', 'r', 'c')) and (cls.oid is null or cls.relkind='c') )`, }), From 12c6319244e4836c5bb6bbff2f786bf73487c574 Mon Sep 17 00:00:00 2001 From: Kale Blankenship Date: Wed, 28 Aug 2019 12:50:51 -0700 Subject: [PATCH 61/70] Include ParameterOIDs when preparing statements on new pool connections ParameterOIDs passed to ConnPool.PrepareEx are used to prepare the statement on existing connections in the pool. If additional connections are later created ParameterOIDs are omitted, potentially causing query failures. --- conn_pool.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/conn_pool.go b/conn_pool.go index e8972a0b..344f00d7 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -341,7 +341,8 @@ func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) { } for _, ps := range p.preparedStatements { - if _, err := c.Prepare(ps.Name, ps.SQL); err != nil { + opts := &PrepareExOptions{ParameterOIDs: ps.ParameterOIDs} + if _, err := c.PrepareEx(context.Background(), ps.Name, ps.SQL, opts); err != nil { c.die(err) return nil, err } From 78f498fc43f957b2eccdac1d002798ee3c277a5c Mon Sep 17 00:00:00 2001 From: Kale Blankenship Date: Sat, 31 Aug 2019 10:27:19 -0700 Subject: [PATCH 62/70] Add ConnPool.AcquireEx --- conn_pool.go | 19 ++++++ conn_pool_test.go | 144 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 163 insertions(+) diff --git a/conn_pool.go b/conn_pool.go index 344f00d7..95e1b015 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -110,6 +110,25 @@ func (p *ConnPool) Acquire() (*Conn, error) { return c, err } +func (p *ConnPool) AcquireEx(ctx context.Context) (*Conn, error) { + var deadline *time.Time + + if p.acquireTimeout > 0 { + tmp := time.Now().Add(p.acquireTimeout) + deadline = &tmp + } + + ctxDeadline, ok := ctx.Deadline() + if ok && (deadline == nil || ctxDeadline.Before(*deadline)) { + deadline = &ctxDeadline + } + + p.cond.L.Lock() + c, err := p.acquire(deadline) + p.cond.L.Unlock() + return c, err +} + // deadlinePassed returns true if the given deadline has passed. func (p *ConnPool) deadlinePassed(deadline *time.Time) bool { return deadline != nil && time.Now().After(*deadline) diff --git a/conn_pool_test.go b/conn_pool_test.go index 84a74aed..83bdf1fd 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -45,6 +45,12 @@ func acquireWithTimeTaken(pool *pgx.ConnPool) (*pgx.Conn, time.Duration, error) return c, time.Since(startTime), err } +func acquireExWithTimeTaken(pool *pgx.ConnPool, ctx context.Context) (*pgx.Conn, time.Duration, error) { + startTime := time.Now() + c, err := pool.AcquireEx(ctx) + return c, time.Since(startTime), err +} + func TestNewConnPool(t *testing.T) { t.Parallel() @@ -315,6 +321,144 @@ func TestPoolWithoutAcquireTimeoutSet(t *testing.T) { } } +func TestPoolWithAcquireExContextTimeoutSet(t *testing.T) { + t.Parallel() + + config := pgx.ConnPoolConfig{ + ConnConfig: *defaultConnConfig, + MaxConnections: 1, + } + + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, config.MaxConnections) + defer releaseAllConnections(pool, allConnections) + + ctxTimeout := 2 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) + defer cancel() + + // ... then try to consume 1 more. It should fail after a short timeout. + _, timeTaken, err := acquireExWithTimeTaken(pool, ctx) + + if err == nil || err != pgx.ErrAcquireTimeout { + t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) + } + if timeTaken < ctxTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", ctxTimeout, timeTaken) + } +} + +func TestPoolWithAcquireExPoolTimeoutLower(t *testing.T) { + t.Parallel() + + connAllocTimeout := 2 * time.Second + config := pgx.ConnPoolConfig{ + ConnConfig: *defaultConnConfig, + MaxConnections: 1, + AcquireTimeout: connAllocTimeout, + } + + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, config.MaxConnections) + defer releaseAllConnections(pool, allConnections) + + ctxTimeout := 5 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) + defer cancel() + + // ... then try to consume 1 more. It should fail after a short timeout. + _, timeTaken, err := acquireExWithTimeTaken(pool, ctx) + + if err == nil || err != pgx.ErrAcquireTimeout { + t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) + } + if timeTaken < connAllocTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", connAllocTimeout, timeTaken) + } + if timeTaken > ctxTimeout { + t.Fatalf("Expected connection allocation time to be less than %v, instead it was '%v'", ctxTimeout, timeTaken) + } +} + +func TestPoolWithAcquireExPoolTimeoutHigher(t *testing.T) { + t.Parallel() + + connAllocTimeout := 5 * time.Second + config := pgx.ConnPoolConfig{ + ConnConfig: *defaultConnConfig, + MaxConnections: 1, + AcquireTimeout: connAllocTimeout, + } + + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, config.MaxConnections) + defer releaseAllConnections(pool, allConnections) + + ctxTimeout := 2 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) + defer cancel() + + // ... then try to consume 1 more. It should fail after a short timeout. + _, timeTaken, err := acquireExWithTimeTaken(pool, ctx) + + if err == nil || err != pgx.ErrAcquireTimeout { + t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) + } + if timeTaken < ctxTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", ctxTimeout, timeTaken) + } + if timeTaken > connAllocTimeout { + t.Fatalf("Expected connection allocation time to be less than %v, instead it was '%v'", connAllocTimeout, timeTaken) + } +} + +func TestPoolWithoutAcquireExTimeoutSet(t *testing.T) { + t.Parallel() + + maxConnections := 1 + pool := createConnPool(t, maxConnections) + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, maxConnections) + + // ... then try to consume 1 more. It should hang forever. + // To unblock it we release the previously taken connection in a goroutine. + stopDeadWaitTimeout := 5 * time.Second + timer := time.AfterFunc(stopDeadWaitTimeout+100*time.Millisecond, func() { + releaseAllConnections(pool, allConnections) + }) + defer timer.Stop() + + conn, timeTaken, err := acquireExWithTimeTaken(pool, context.Background()) + if err == nil { + pool.Release(conn) + } else { + t.Fatalf("Expected error to be nil, instead it was '%v'", err) + } + if timeTaken < stopDeadWaitTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", stopDeadWaitTimeout, timeTaken) + } +} + func TestPoolErrClosedPool(t *testing.T) { t.Parallel() From 35908df25f9079a73270eefb1cf4c3df635ee876 Mon Sep 17 00:00:00 2001 From: Dmitriy Garanzha Date: Mon, 2 Sep 2019 16:57:21 +0300 Subject: [PATCH 63/70] Filter automatically created table array types. --- conn.go | 2 ++ pgmock/pgmock.go | 2 ++ 2 files changed, 4 insertions(+) diff --git a/conn.go b/conn.go index 9dc4cdbf..b98434f7 100644 --- a/conn.go +++ b/conn.go @@ -611,12 +611,14 @@ func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) { end from pg_type t left join pg_type base_type on t.typelem=base_type.oid +left join pg_class base_cls ON base_type.typrelid = base_cls.oid left join pg_namespace nsp on t.typnamespace=nsp.oid left join pg_class cls on t.typrelid=cls.oid where ( t.typtype in('b', 'p', 'r', 'e', 'c') and (base_type.oid is null or base_type.typtype in('b', 'p', 'r', 'c')) and (cls.oid is null or cls.relkind='c') + and (base_cls.oid is null or base_cls.relkind = 'c') )` ) diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go index 5c3fdc27..7b9e7991 100644 --- a/pgmock/pgmock.go +++ b/pgmock/pgmock.go @@ -210,12 +210,14 @@ func PgxInitSteps() []Step { end from pg_type t left join pg_type base_type on t.typelem=base_type.oid +left join pg_class base_cls ON base_type.typrelid = base_cls.oid left join pg_namespace nsp on t.typnamespace=nsp.oid left join pg_class cls on t.typrelid=cls.oid where ( t.typtype in('b', 'p', 'r', 'e', 'c') and (base_type.oid is null or base_type.typtype in('b', 'p', 'r', 'c')) and (cls.oid is null or cls.relkind='c') + and (base_cls.oid is null or base_cls.relkind = 'c') )`, }), ExpectMessage(&pgproto3.Describe{ From f26e4c0e6921395ee2556c61c0152b031254ff6c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 2 Sep 2019 12:19:55 -0500 Subject: [PATCH 64/70] Update status of v4 --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index b7051f65..0a4cacc3 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,8 @@ if err != nil { ## v4 Coming Soon -This is the current stable v3 version. v4 is currently is in prelease status. Consider using [v4](https://github.com/jackc/pgx/tree/v4) for new development or test upgrading existing applications. +This is the current stable v3 version. v4 is currently is in release candidate status. Consider using +[v4](https://github.com/jackc/pgx/tree/v4) for new development or test upgrading existing applications. ## Features From 2d89e52d6f20255fd98982260d2ea875583161fe Mon Sep 17 00:00:00 2001 From: David Date: Mon, 9 Sep 2019 10:52:31 -0700 Subject: [PATCH 65/70] Add composite registering to init steps. --- pgmock/pgmock.go | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go index 4d15f7b8..4a64b506 100644 --- a/pgmock/pgmock.go +++ b/pgmock/pgmock.go @@ -528,6 +528,47 @@ where ( SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), }...) + steps = append(steps, []Step{ + ExpectMessage(&pgproto3.Parse{ + Query: "select t.oid, t.typname\nfrom pg_type t\n\tjoin pg_class cls on t.typrelid=cls.oid\nwhere t.typtype = 'c'\n\tand cls.relkind='c'", + }), + ExpectMessage(&pgproto3.Describe{ + ObjectType: 'S', + }), + ExpectMessage(&pgproto3.Sync{}), + SendMessage(&pgproto3.ParseComplete{}), + SendMessage(&pgproto3.ParameterDescription{}), + SendMessage(&pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + {Name: "oid", + TableOID: 1247, + TableAttributeNumber: 65534, + DataTypeOID: 26, + DataTypeSize: 4, + TypeModifier: -1, + Format: 0, + }, + {Name: "typname", + TableOID: 1247, + TableAttributeNumber: 1, + DataTypeOID: 19, + DataTypeSize: 64, + TypeModifier: -1, + Format: 0, + }, + }, + }), + SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + ExpectMessage(&pgproto3.Bind{ + ResultFormatCodes: []int16{1, 1}, + }), + ExpectMessage(&pgproto3.Execute{}), + ExpectMessage(&pgproto3.Sync{}), + SendMessage(&pgproto3.BindComplete{}), + SendMessage(&pgproto3.CommandComplete{CommandTag: "SELECT 0"}), + SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + }...) + return steps } From fed099f04a46a5ec837d8ce96c461d6d16e5fb21 Mon Sep 17 00:00:00 2001 From: David Date: Mon, 9 Sep 2019 11:26:13 -0700 Subject: [PATCH 66/70] Merge branch 'master' into composite --- README.md | 3 +- batch.go | 28 ++- batch_test.go | 52 +++++ chunkreader/chunkreader.go | 6 +- conn.go | 406 +++++++++++++++++++++++++++++++----- conn_config_test.go.example | 4 + conn_config_test.go.travis | 3 + conn_pool.go | 24 ++- conn_pool_test.go | 144 +++++++++++++ conn_test.go | 223 ++++++++++++++++++++ doc.go | 2 +- pgmock/pgmock.go | 1 + pgtype/convert.go | 33 ++- pgtype/uuid.go | 2 +- pgtype/uuid_test.go | 21 ++ query.go | 19 ++ query_test.go | 43 +++- stdlib/sql.go | 6 +- stdlib/sql_test.go | 1 + 19 files changed, 951 insertions(+), 70 deletions(-) diff --git a/README.md b/README.md index b7051f65..0a4cacc3 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,8 @@ if err != nil { ## v4 Coming Soon -This is the current stable v3 version. v4 is currently is in prelease status. Consider using [v4](https://github.com/jackc/pgx/tree/v4) for new development or test upgrading existing applications. +This is the current stable v3 version. v4 is currently is in release candidate status. Consider using +[v4](https://github.com/jackc/pgx/tree/v4) for new development or test upgrading existing applications. ## Features diff --git a/batch.go b/batch.go index 4b624387..7f5422dc 100644 --- a/batch.go +++ b/batch.go @@ -135,7 +135,7 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { _, err = b.conn.conn.Write(buf) if err != nil { - b.conn.die(err) + b.die(err) return err } @@ -268,6 +268,23 @@ func (b *Batch) Close() (err error) { } } + for b.conn.pendingReadyForQueryCount > 0 { + msg, err := b.conn.rxMsg() + if err != nil { + return err + } + + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + return b.conn.rxErrorResponse(msg) + default: + err = b.conn.processContextFreeMsg(msg) + if err != nil { + return err + } + } + } + if err = b.conn.ensureConnectionReadyForQuery(); err != nil { return err } @@ -281,10 +298,13 @@ func (b *Batch) die(err error) { } b.err = err - b.conn.die(err) + if b.conn != nil { + err = b.conn.termContext(err) + b.conn.die(err) - if b.conn != nil && b.connPool != nil { - b.connPool.Release(b.conn) + if b.connPool != nil { + b.connPool.Release(b.conn) + } } } diff --git a/batch_test.go b/batch_test.go index 61bbe357..d0e26875 100644 --- a/batch_test.go +++ b/batch_test.go @@ -701,3 +701,55 @@ func TestTxBeginBatchRollback(t *testing.T) { ensureConnValid(t, conn) } + +func TestConnBeginBatchDeferredError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) + + batch := conn.BeginBatch() + batch.Queue(`update t set n=n+1 where id='b' returning *`, + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + rows, err := batch.QueryResults() + if err != nil { + t.Error(err) + } + + for rows.Next() { + var id string + var n int32 + err = rows.Scan(&id, &n) + if err != nil { + t.Fatal(err) + } + } + + err = batch.Close() + if err == nil { + t.Fatal("expected error 23505 but got none") + } + + if err, ok := err.(pgx.PgError); !ok || err.Code != "23505" { + t.Fatalf("expected error 23505, got %v", err) + } + + ensureConnValid(t, conn) +} diff --git a/chunkreader/chunkreader.go b/chunkreader/chunkreader.go index f8d437b2..5c36292d 100644 --- a/chunkreader/chunkreader.go +++ b/chunkreader/chunkreader.go @@ -28,7 +28,11 @@ func NewChunkReader(r io.Reader) *ChunkReader { func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) { if options.MinBufLen == 0 { - options.MinBufLen = 4096 + // By historical reasons Postgres currently has 8KB send buffer inside, + // so here we want to have at least the same size buffer. + // @see https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134 + // @see https://www.postgresql.org/message-id/0cdc5485-cb3c-5e16-4a46-e3b2f7a41322%40ya.ru + options.MinBufLen = 8192 } return &ChunkReader{ diff --git a/conn.go b/conn.go index b613707e..b4469af9 100644 --- a/conn.go +++ b/conn.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "io/ioutil" + "math" "net" "net/url" "os" @@ -61,10 +62,50 @@ type NoticeHandler func(*Conn, *Notice) // DialFunc is a function that can be used to connect to a PostgreSQL server type DialFunc func(network, addr string) (net.Conn, error) +// TargetSessionType represents target session attrs configuration parameter. +type TargetSessionType string + +// Block enumerates available values for TargetSessionType. +const ( + AnyTargetSession = "any" + ReadWriteTargetSession = "read-write" +) + +func (t TargetSessionType) isValid() error { + switch t { + case "", AnyTargetSession, ReadWriteTargetSession: + return nil + } + + return errors.New("invalid value for target_session_attrs, expected \"any\" or \"read-write\"") +} + +func (t TargetSessionType) writableRequired() bool { + return t == ReadWriteTargetSession +} + // ConnConfig contains all the options used to establish a connection. type ConnConfig struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) - Port uint16 // default: 5432 + // Name of host to connect to. (e.g. localhost) + // If a host name begins with a slash, it specifies Unix-domain communication + // rather than TCP/IP communication; the value is the name of the directory + // in which the socket file is stored. (e.g. /private/tmp) + // The default behavior when host is not specified, or is empty, is to connect to localhost. + // + // A comma-separated list of host names is also accepted, + // in which case each host name in the list is tried in order; + // an empty item in the list selects the default behavior as explained above. + // @see https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS + Host string + + // Port number to connect to at the server host, + // or socket file name extension for Unix-domain connections. + // An empty or zero value, specifies the default port number — 5432. + // + // If multiple hosts were given in the Host parameter, then + // this parameter may specify a single port number to be used for all hosts, + // or for those that haven't port explicitly defined. + Port uint16 Database string User string // default: OS user name Password string @@ -89,22 +130,94 @@ type ConnConfig struct { // used by default. The same functionality can be controlled on a per query // basis by setting QueryExOptions.SimpleProtocol. PreferSimpleProtocol bool + + // TargetSessionAttr allows to specify which servers are accepted for this connection. + // "any", meaning that any kind of servers can be accepted. This is as well the default value. + // "read-write", to disallow connections to read-only servers, hot standbys for example. + // @see https://www.postgresql.org/message-id/CAD__OuhqPRGpcsfwPHz_PDqAGkoqS1UvnUnOnAB-LBWBW=wu4A@mail.gmail.com + // @see https://paquier.xyz/postgresql-2/postgres-10-libpq-read-write/ + // + // The query SHOW transaction_read_only will be sent upon any successful connection; + // if it returns on, the connection will be closed. + // If multiple hosts were specified in the connection string, + // any remaining servers will be tried just as if the connection attempt had failed. + // The default value of this parameter, any, regards all connections as acceptable. + TargetSessionAttrs TargetSessionType } -func (cc *ConnConfig) networkAddress() (network, address string) { - network = "tcp" - address = fmt.Sprintf("%s:%d", cc.Host, cc.Port) - // See if host is a valid path, if yes connect with a socket - if _, err := os.Stat(cc.Host); err == nil { - // For backward compatibility accept socket file paths -- but directories are now preferred - network = "unix" - address = cc.Host - if !strings.Contains(address, "/.s.PGSQL.") { - address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10) - } +// hostAddr represents network end point defined as hostname or IP + port. +type hostAddr struct { + Host string + Port uint16 +} + +// Network returns the address's network name, "tcp". +func (a *hostAddr) Network() string { return "tcp" } + +// String implements net.Addr String method. +func (a *hostAddr) String() string { + if a == nil { + return "" } - return network, address + return net.JoinHostPort(a.Host, strconv.Itoa(int(a.Port))) +} + +func (cc *ConnConfig) networkAddresses() ([]net.Addr, error) { + // See if host is a valid path, if yes connect with a unix socket + if _, err := os.Stat(cc.Host); err == nil { + // For backward compatibility accept socket file paths -- but directories are now preferred + network := "unix" + address := cc.Host + + if !strings.Contains(address, "/.s.PGSQL.") { + address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatUint(uint64(cc.Port), 10) + } + + addrs := []net.Addr{ + &net.UnixAddr{Name: address, Net: network}, + } + + return addrs, nil + } + + if cc.Host == "" { + addrs := []net.Addr{ + &net.TCPAddr{Port: int(cc.Port)}, + } + + return addrs, nil + } + + var addrs []net.Addr + + hostports := strings.Split(cc.Host, ",") + for i, hostport := range hostports { + if hostport == "" { + return nil, fmt.Errorf("multi-host part %d is empty, at least host or port must be defined", i) + } + + // It's not possible to use net.TCPAddr here, cuz host may be hostname. + addr := hostAddr{ + Host: hostport, + Port: cc.Port, + } + + pos := strings.IndexByte(hostport, ':') + if pos != -1 { + p, err := strconv.ParseUint(hostport[pos+1:], 10, 16) + if err != nil { + return nil, fmt.Errorf("multi-host part %d (%s) has invalid port format", i, hostport) + } + + addr.Host = hostport[:pos] + addr.Port = uint16(p) + } + + addrs = append(addrs, &addr) + } + + return addrs, nil } // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. @@ -145,6 +258,10 @@ type Conn struct { ConnInfo *pgtype.ConnInfo frontend *pgproto3.Frontend + + // In case of Multiple Hosts we need to know what addr was used to connect. + // This address will be used to send a cancellation request. + addr net.Addr } // PreparedStatement is a description of a prepared statement @@ -190,7 +307,8 @@ type Identifier []string func (ident Identifier) Sanitize() string { parts := make([]string, len(ident)) for i := range ident { - parts[i] = `"` + strings.Replace(ident[i], `"`, `""`, -1) + `"` + s := strings.Replace(ident[i], string([]byte{0}), "", -1) + parts[i] = `"` + strings.Replace(s, `"`, `""`, -1) + `"` } return strings.Join(parts, ".") } @@ -262,33 +380,123 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) } } + if err := c.config.TargetSessionAttrs.isValid(); err != nil { + return nil, err + } + c.onNotice = config.OnNotice - network, address := c.config.networkAddress() if c.config.Dial == nil { d := defaultDialer() c.config.Dial = d.Dial } - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address}) - } - err = c.connect(config, network, address, config.TLSConfig) - if err != nil && config.UseFallbackTLS { - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err}) - } - err = c.connect(config, network, address, config.FallbackTLSConfig) - } - + addrs, err := c.config.networkAddresses() if err != nil { - if c.shouldLog(LogLevelError) { - c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err}) - } return nil, err } - return c, nil + var errs []error + for _, addr := range addrs { + network, address := addr.Network(), addr.String() + + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{ + "network": network, + "address": address, + }) + } + + err = c.connect(config, network, address, config.TLSConfig) + if err != nil && config.UseFallbackTLS { + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{ + "err": err, + "network": network, + "address": address, + }) + } + err = c.connect(config, network, address, config.FallbackTLSConfig) + } + + if err != nil { + if c.shouldLog(LogLevelError) { + c.log(LogLevelError, "connect failed", map[string]interface{}{ + "err": err, + "network": network, + "address": address, + }) + } + + // On any auth errors return immediately + if pgErr, ok := err.(PgError); ok { + switch pgErr.Code { + // @see: https://www.postgresql.org/docs/current/errcodes-appendix.html + case "28000", "28P01": // Invalid Authorization Specification + return nil, pgErr + } + } + + errs = append(errs, err) + continue + } + + err = c.checkWritable() + if err != nil { + c.die(err) + + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "host is not writable", map[string]interface{}{ + "err": err, + "network": network, + "address": address, + }) + } + + errs = append(errs, err) + continue + } + + c.addr = addr + + return c, nil + } + + // To keep backwards compatibility, if specific error type expected. + if len(errs) == 1 { + return nil, errs[0] + } + + errmsgs := make([]string, len(errs)) + for i, err := range errs { + errmsgs[i] = err.Error() + } + + return nil, errors.New(strings.Join(errmsgs, "; ")) +} + +func (c *Conn) checkWritable() error { + if !c.config.TargetSessionAttrs.writableRequired() { + return nil + } + + var st string + err := c.QueryRowEx(context.Background(), "SHOW transaction_read_only", &QueryExOptions{SimpleProtocol: true}). + Scan(&st) + + if err != nil { + return errors.Wrap(err, "failed to fetch \"transaction_read_only\" state") + } + + switch st { + case "on": + return errors.New("writable transactions disabled by server") + case "off": + // If transaction_read_only = off, then connection is writable. + return nil + } + + return errors.New("unexpected \"transaction_read_only\" status") } func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) { @@ -403,6 +611,7 @@ func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) { end from pg_type t left join pg_type base_type on t.typelem=base_type.oid +left join pg_class base_cls ON base_type.typrelid = base_cls.oid left join pg_namespace nsp on t.typnamespace=nsp.oid where ( t.typtype in('b', 'p', 'r', 'e') @@ -750,6 +959,10 @@ func (old ConnConfig) Merge(other ConnConfig) ConnConfig { cc.PreferSimpleProtocol = old.PreferSimpleProtocol || other.PreferSimpleProtocol + if other.TargetSessionAttrs != "" { + cc.TargetSessionAttrs = other.TargetSessionAttrs + } + cc.RuntimeParams = make(map[string]string) for k, v := range old.RuntimeParams { cc.RuntimeParams[k] = v @@ -777,16 +990,26 @@ func ParseURI(uri string) (ConnConfig, error) { cp.Password, _ = url.User.Password() } - parts := strings.SplitN(url.Host, ":", 2) - cp.Host = parts[0] - if len(parts) == 2 { - p, err := strconv.ParseUint(parts[1], 10, 16) - if err != nil { - return cp, err + hasMuliHosts := strings.IndexByte(url.Host, ',') != -1 + if !hasMuliHosts { + parts := strings.SplitN(url.Host, ":", 2) + cp.Host = parts[0] + if len(parts) == 2 { + p, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return cp, err + } + cp.Port = uint16(p) } - cp.Port = uint16(p) + } else { + cp.Host = url.Host } + cp.Database = strings.TrimLeft(url.Path, "/") + cp.TargetSessionAttrs = TargetSessionType(url.Query().Get("target_session_attrs")) + if err := cp.TargetSessionAttrs.isValid(); err != nil { + return cp, err + } if pgtimeout := url.Query().Get("connect_timeout"); pgtimeout != "" { timeout, err := strconv.ParseInt(pgtimeout, 10, 64) @@ -810,11 +1033,12 @@ func ParseURI(uri string) (ConnConfig, error) { } ignoreKeys := map[string]struct{}{ - "connect_timeout": {}, - "sslcert": {}, - "sslkey": {}, - "sslmode": {}, - "sslrootcert": {}, + "connect_timeout": {}, + "sslcert": {}, + "sslkey": {}, + "sslmode": {}, + "sslrootcert": {}, + "target_session_attrs": {}, } cp.RuntimeParams = make(map[string]string) @@ -834,6 +1058,7 @@ func ParseURI(uri string) (ConnConfig, error) { if cp.Password == "" { pgpass(&cp) } + return cp, nil } @@ -859,6 +1084,7 @@ func ParseDSN(s string) (ConnConfig, error) { cp.RuntimeParams = make(map[string]string) + var hostval, portval string for _, b := range m { switch b[1] { case "user": @@ -866,13 +1092,9 @@ func ParseDSN(s string) (ConnConfig, error) { case "password": cp.Password = b[2] case "host": - cp.Host = b[2] + hostval = b[2] case "port": - p, err := strconv.ParseUint(b[2], 10, 16) - if err != nil { - return cp, err - } - cp.Port = uint16(p) + portval = b[2] case "dbname": cp.Database = b[2] case "sslmode": @@ -891,23 +1113,93 @@ func ParseDSN(s string) (ConnConfig, error) { d := defaultDialer() d.Timeout = time.Duration(timeout) * time.Second cp.Dial = d.Dial + case "target_session_attrs": + cp.TargetSessionAttrs = TargetSessionType(b[2]) + if err := cp.TargetSessionAttrs.isValid(); err != nil { + return cp, err + } default: cp.RuntimeParams[b[1]] = b[2] } } - err := configTLS(tlsArgs, &cp) + host, port, err := parseHostPortDSN(hostval, portval) if err != nil { return cp, err } + + cp.Host, cp.Port = host, port + + err = configTLS(tlsArgs, &cp) + if err != nil { + return cp, err + } + if cp.Password == "" { pgpass(&cp) } + return cp, nil } -// ParseConnectionString parses either a URI or a DSN connection string. -// see ParseURI and ParseDSN for details. +func parseHostPortDSN(hostval, portval string) (host string, port uint16, err error) { + if portval == "" { + return hostval, 0, nil + } + + hosts := strings.Split(hostval, ",") + ports := strings.Split(portval, ",") + + if len(ports) == 1 { + port, err := parsePort(portval) + if err != nil { + return "", 0, errors.Errorf("invalid port: %v", err) + } + + return hostval, port, nil + } + + if len(hosts) != len(ports) { + return "", 0, errors.New("the number of hosts and ports must be the same") + } + + hostports := make([]string, len(hosts)) + for i, host := range hosts { + hostports[i] = host + ":" + ports[i] + } + + return strings.Join(hostports, ","), 0, nil +} + +func parsePort(s string) (uint16, error) { + port, err := strconv.ParseUint(s, 10, 16) + if err != nil { + return 0, err + } + if port < 1 || port > math.MaxUint16 { + return 0, errors.New("outside range") + } + return uint16(port), nil +} + +// ParseConnectionString parses either a URI or a DSN connection string and builds ConnConfig. +// +// # Example DSN +// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca +// +// # Example URL +// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca +// +// ParseConnectionString supports specifying multiple hosts in similar manner to libpq. +// Host and port may include comma separated values that will be tried in order. +// This can be used as part of a high availability system. +// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. +// +// # Example URL +// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb +// +// # Example DSN +// user=jack password=secret host=host1,host2,host3 port=5432,5433,5434 dbname=mydb sslmode=verify-ca func ParseConnectionString(s string) (ConnConfig, error) { if u, err := url.Parse(s); err == nil && u.Scheme != "" { return ParseURI(s) @@ -932,6 +1224,8 @@ func ParseConnectionString(s string) (ConnConfig, error) { // PGSSLROOTCERT // PGAPPNAME // PGCONNECT_TIMEOUT +// PGTARGETSESSIONATTRS +// @see: https://www.postgresql.org/docs/10/libpq-envars.html // // Important TLS Security Notes: // ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This @@ -977,6 +1271,11 @@ func ParseEnvLibpq() (ConnConfig, error) { } } + cc.TargetSessionAttrs = TargetSessionType(os.Getenv("PGTARGETSESSIONATTRS")) + if err := cc.TargetSessionAttrs.isValid(); err != nil { + return cc, err + } + tlsArgs := configTLSArgs{ sslMode: os.Getenv("PGSSLMODE"), sslKey: os.Getenv("PGSSLKEY"), @@ -1692,8 +1991,7 @@ func quoteIdentifier(s string) string { } func doCancel(c *Conn) error { - network, address := c.config.networkAddress() - cancelConn, err := c.config.Dial(network, address) + cancelConn, err := c.config.Dial(c.addr.Network(), c.addr.String()) if err != nil { return err } diff --git a/conn_config_test.go.example b/conn_config_test.go.example index 096e1354..2ca84ac3 100644 --- a/conn_config_test.go.example +++ b/conn_config_test.go.example @@ -7,6 +7,8 @@ import ( // "go/build" // "io/ioutil" // "path" + // "net" + // "time" "github.com/jackc/pgx" ) @@ -14,6 +16,7 @@ import ( var defaultConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} // To skip tests for specific connection / authentication types set that connection param to nil +var multihostConnConfig *pgx.ConnConfig = nil var tcpConnConfig *pgx.ConnConfig = nil var unixSocketConnConfig *pgx.ConnConfig = nil var md5ConnConfig *pgx.ConnConfig = nil @@ -24,6 +27,7 @@ var customDialerConnConfig *pgx.ConnConfig = nil var replicationConnConfig *pgx.ConnConfig = nil var cratedbConnConfig *pgx.ConnConfig = nil +// var multihostConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "2.2.2.2:1,127.0.0.1,4.2.4.2", User: "pgx_md5", Password: "secret", Database: "pgx_test", Dial: (&net.Dialer{KeepAlive: 5 * time.Minute, Timeout: 100 * time.Millisecond}).Dial} // var tcpConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} // var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"} // var md5ConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} diff --git a/conn_config_test.go.travis b/conn_config_test.go.travis index cf29a743..fbfb5252 100644 --- a/conn_config_test.go.travis +++ b/conn_config_test.go.travis @@ -5,9 +5,12 @@ import ( "github.com/jackc/pgx" "os" "strconv" + "net" + "time" ) var defaultConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} +var multihostConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "2.2.2.2:1,127.0.0.1,4.2.4.2", User: "pgx_md5", Password: "secret", Database: "pgx_test", Dial: (&net.Dialer{KeepAlive: 5 * time.Minute, Timeout: 100 * time.Millisecond}).Dial} var tcpConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} var unixSocketConnConfig = &pgx.ConnConfig{Host: "/var/run/postgresql", User: "postgres", Database: "pgx_test"} var md5ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} diff --git a/conn_pool.go b/conn_pool.go index 47a0b391..95e1b015 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -110,6 +110,25 @@ func (p *ConnPool) Acquire() (*Conn, error) { return c, err } +func (p *ConnPool) AcquireEx(ctx context.Context) (*Conn, error) { + var deadline *time.Time + + if p.acquireTimeout > 0 { + tmp := time.Now().Add(p.acquireTimeout) + deadline = &tmp + } + + ctxDeadline, ok := ctx.Deadline() + if ok && (deadline == nil || ctxDeadline.Before(*deadline)) { + deadline = &ctxDeadline + } + + p.cond.L.Lock() + c, err := p.acquire(deadline) + p.cond.L.Unlock() + return c, err +} + // deadlinePassed returns true if the given deadline has passed. func (p *ConnPool) deadlinePassed(deadline *time.Time) bool { return deadline != nil && time.Now().After(*deadline) @@ -319,7 +338,7 @@ func (p *ConnPool) createConnection() (*Conn, error) { func (p *ConnPool) createConnectionUnlocked() (*Conn, error) { p.inProgressConnects++ p.cond.L.Unlock() - c, err := Connect(p.config) + c, err := connect(p.config, p.connInfo.DeepCopy()) p.cond.L.Lock() p.inProgressConnects-- @@ -341,7 +360,8 @@ func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) { } for _, ps := range p.preparedStatements { - if _, err := c.Prepare(ps.Name, ps.SQL); err != nil { + opts := &PrepareExOptions{ParameterOIDs: ps.ParameterOIDs} + if _, err := c.PrepareEx(context.Background(), ps.Name, ps.SQL, opts); err != nil { c.die(err) return nil, err } diff --git a/conn_pool_test.go b/conn_pool_test.go index 84a74aed..83bdf1fd 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -45,6 +45,12 @@ func acquireWithTimeTaken(pool *pgx.ConnPool) (*pgx.Conn, time.Duration, error) return c, time.Since(startTime), err } +func acquireExWithTimeTaken(pool *pgx.ConnPool, ctx context.Context) (*pgx.Conn, time.Duration, error) { + startTime := time.Now() + c, err := pool.AcquireEx(ctx) + return c, time.Since(startTime), err +} + func TestNewConnPool(t *testing.T) { t.Parallel() @@ -315,6 +321,144 @@ func TestPoolWithoutAcquireTimeoutSet(t *testing.T) { } } +func TestPoolWithAcquireExContextTimeoutSet(t *testing.T) { + t.Parallel() + + config := pgx.ConnPoolConfig{ + ConnConfig: *defaultConnConfig, + MaxConnections: 1, + } + + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, config.MaxConnections) + defer releaseAllConnections(pool, allConnections) + + ctxTimeout := 2 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) + defer cancel() + + // ... then try to consume 1 more. It should fail after a short timeout. + _, timeTaken, err := acquireExWithTimeTaken(pool, ctx) + + if err == nil || err != pgx.ErrAcquireTimeout { + t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) + } + if timeTaken < ctxTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", ctxTimeout, timeTaken) + } +} + +func TestPoolWithAcquireExPoolTimeoutLower(t *testing.T) { + t.Parallel() + + connAllocTimeout := 2 * time.Second + config := pgx.ConnPoolConfig{ + ConnConfig: *defaultConnConfig, + MaxConnections: 1, + AcquireTimeout: connAllocTimeout, + } + + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, config.MaxConnections) + defer releaseAllConnections(pool, allConnections) + + ctxTimeout := 5 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) + defer cancel() + + // ... then try to consume 1 more. It should fail after a short timeout. + _, timeTaken, err := acquireExWithTimeTaken(pool, ctx) + + if err == nil || err != pgx.ErrAcquireTimeout { + t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) + } + if timeTaken < connAllocTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", connAllocTimeout, timeTaken) + } + if timeTaken > ctxTimeout { + t.Fatalf("Expected connection allocation time to be less than %v, instead it was '%v'", ctxTimeout, timeTaken) + } +} + +func TestPoolWithAcquireExPoolTimeoutHigher(t *testing.T) { + t.Parallel() + + connAllocTimeout := 5 * time.Second + config := pgx.ConnPoolConfig{ + ConnConfig: *defaultConnConfig, + MaxConnections: 1, + AcquireTimeout: connAllocTimeout, + } + + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, config.MaxConnections) + defer releaseAllConnections(pool, allConnections) + + ctxTimeout := 2 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) + defer cancel() + + // ... then try to consume 1 more. It should fail after a short timeout. + _, timeTaken, err := acquireExWithTimeTaken(pool, ctx) + + if err == nil || err != pgx.ErrAcquireTimeout { + t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) + } + if timeTaken < ctxTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", ctxTimeout, timeTaken) + } + if timeTaken > connAllocTimeout { + t.Fatalf("Expected connection allocation time to be less than %v, instead it was '%v'", connAllocTimeout, timeTaken) + } +} + +func TestPoolWithoutAcquireExTimeoutSet(t *testing.T) { + t.Parallel() + + maxConnections := 1 + pool := createConnPool(t, maxConnections) + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, maxConnections) + + // ... then try to consume 1 more. It should hang forever. + // To unblock it we release the previously taken connection in a goroutine. + stopDeadWaitTimeout := 5 * time.Second + timer := time.AfterFunc(stopDeadWaitTimeout+100*time.Millisecond, func() { + releaseAllConnections(pool, allConnections) + }) + defer timer.Stop() + + conn, timeTaken, err := acquireExWithTimeTaken(pool, context.Background()) + if err == nil { + pool.Release(conn) + } else { + t.Fatalf("Expected error to be nil, instead it was '%v'", err) + } + if timeTaken < stopDeadWaitTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", stopDeadWaitTimeout, timeTaken) + } +} + func TestPoolErrClosedPool(t *testing.T) { t.Parallel() diff --git a/conn_test.go b/conn_test.go index 6ca00c6d..c6ce50cc 100644 --- a/conn_test.go +++ b/conn_test.go @@ -84,6 +84,105 @@ func TestConnect(t *testing.T) { } } +func TestConnectWithMultiHost(t *testing.T) { + t.Parallel() + + if multihostConnConfig == nil { + t.Skip("Skipping due to undefined multihostConnConfig") + } + + conn, err := pgx.Connect(*multihostConnConfig) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + + if _, present := conn.RuntimeParams["server_version"]; !present { + t.Error("Runtime parameters not stored") + } + + if conn.PID() == 0 { + t.Error("Backend PID not stored") + } + + var currentDB string + err = conn.QueryRow("select current_database()").Scan(¤tDB) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if currentDB != defaultConnConfig.Database { + t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database) + } + + var user string + err = conn.QueryRow("select current_user").Scan(&user) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if user != defaultConnConfig.User { + t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User) + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + +func TestConnectWithMultiHostWritable(t *testing.T) { + t.Parallel() + + if multihostConnConfig == nil { + t.Skip("Skipping due to undefined multihostConnConfig") + } + + connConfig := *multihostConnConfig + connConfig.TargetSessionAttrs = pgx.ReadWriteTargetSession + + conn := mustConnect(t, connConfig) + defer closeConn(t, conn) + + if _, present := conn.RuntimeParams["server_version"]; !present { + t.Error("Runtime parameters not stored") + } + + if conn.PID() == 0 { + t.Error("Backend PID not stored") + } + + var currentDB string + err := conn.QueryRow("select current_database()").Scan(¤tDB) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if currentDB != defaultConnConfig.Database { + t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database) + } + + var user string + err = conn.QueryRow("select current_user").Scan(&user) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if user != defaultConnConfig.User { + t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User) + } + + var st string + err = conn.QueryRow("SHOW transaction_read_only").Scan(&st) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + + if st == "on" { + t.Error("Connection is not writable") + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + func TestConnectWithUnixSocketDirectory(t *testing.T) { t.Parallel() @@ -521,6 +620,38 @@ func TestParseURI(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + url: "postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "foo.example.com:5432,bar.example.com:5432", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgres://jack@localhost,10.10.20.30/mydb?application_name=pgxtest&target_session_attrs=read-write", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost,10.10.20.30", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + }, + TargetSessionAttrs: pgx.ReadWriteTargetSession, + }, + }, } for i, tt := range tests { @@ -647,6 +778,50 @@ func TestParseDSN(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + url: "user=jack host=localhost1,localhost2 dbname=mydb connect_timeout=10", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost1,localhost2", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack host=100.200.220.50,localhost43 port=5432,5433 dbname=mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "100.200.220.50:5432,localhost43:5433", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack host=localhost dbname=mydb target_session_attrs=read-write", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + TargetSessionAttrs: pgx.ReadWriteTargetSession, + }, + }, } for i, tt := range tests { @@ -1195,6 +1370,32 @@ func TestExecFailure(t *testing.T) { } } +func TestExecDeferredError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred +); + +insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) + + _, err := conn.Exec(`update t set n=n+1 where id='b'`) + if err == nil { + t.Fatal("expected error 23505 but got none") + } + + if err, ok := err.(pgx.PgError); !ok || err.Code != "23505" { + t.Fatalf("expected error 23505, got %v", err) + } + + ensureConnValid(t, conn) +} + func TestExecFailureWithArguments(t *testing.T) { t.Parallel() @@ -2142,6 +2343,24 @@ func TestSetLogLevel(t *testing.T) { } } +func TestIdentifierSanitizeNullSentToServer(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ident := pgx.Identifier{"foo" + string([]byte{0}) + "bar"} + + var n int64 + err := conn.QueryRow(`select 1 as ` + ident.Sanitize()).Scan(&n) + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Fatal("unexpected n") + } +} + func TestIdentifierSanitize(t *testing.T) { t.Parallel() @@ -2169,6 +2388,10 @@ func TestIdentifierSanitize(t *testing.T) { ident: pgx.Identifier{`you should " not do this`, `please don't`}, expected: `"you should "" not do this"."please don't"`, }, + { + ident: pgx.Identifier{`you should ` + string([]byte{0}) + `not do this`}, + expected: `"you should not do this"`, + }, } for i, tt := range tests { diff --git a/doc.go b/doc.go index 5808c09d..0c2b35d3 100644 --- a/doc.go +++ b/doc.go @@ -225,7 +225,7 @@ notification. return nil } - if notification, err := conn.WaitForNotification(time.Second); err != nil { + if notification, err := conn.WaitForNotification(context.TODO()); err != nil { // do something with notification } diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go index 4a64b506..97093968 100644 --- a/pgmock/pgmock.go +++ b/pgmock/pgmock.go @@ -210,6 +210,7 @@ func PgxInitSteps() []Step { end from pg_type t left join pg_type base_type on t.typelem=base_type.oid +left join pg_class base_cls ON base_type.typrelid = base_cls.oid left join pg_namespace nsp on t.typnamespace=nsp.oid where ( t.typtype in('b', 'p', 'r', 'e') diff --git a/pgtype/convert.go b/pgtype/convert.go index 5dfb738e..029e3d48 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -149,7 +149,7 @@ func underlyingTimeType(val interface{}) (interface{}, bool) { switch refVal.Kind() { case reflect.Ptr: if refVal.IsNil() { - return time.Time{}, false + return nil, false } convVal := refVal.Elem().Interface() return convVal, true @@ -160,7 +160,28 @@ func underlyingTimeType(val interface{}) (interface{}, bool) { return refVal.Convert(timeType).Interface(), true } - return time.Time{}, false + return nil, false +} + +// underlyingUUIDType gets the underlying type that can be converted to [16]byte +func underlyingUUIDType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return time.Time{}, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + uuidType := reflect.TypeOf([16]byte{}) + if refVal.Type().ConvertibleTo(uuidType) { + return refVal.Convert(uuidType).Interface(), true + } + + return nil, false } // underlyingSliceType gets the underlying slice type @@ -401,6 +422,14 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { } } + if dstVal.Kind() == reflect.Array { + if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { + baseArrayType := reflect.PtrTo(reflect.ArrayOf(dstVal.Len(), baseElemType)) + nextDst := dstPtr.Convert(baseArrayType) + return nextDst.Interface(), dstPtr.Type() != nextDst.Type() + } + } + return nil, false } diff --git a/pgtype/uuid.go b/pgtype/uuid.go index 5e1eead5..8d33d8f8 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -39,7 +39,7 @@ func (dst *UUID) Set(src interface{}) error { } *dst = UUID{Bytes: uuid, Status: Present} default: - if originalSrc, ok := underlyingPtrType(src); ok { + if originalSrc, ok := underlyingUUIDType(src); ok { return dst.Set(originalSrc) } return errors.Errorf("cannot convert %v to UUID", value) diff --git a/pgtype/uuid_test.go b/pgtype/uuid_test.go index 162d999f..1eddeda1 100644 --- a/pgtype/uuid_test.go +++ b/pgtype/uuid_test.go @@ -15,6 +15,8 @@ func TestUUIDTranscode(t *testing.T) { }) } +type SomeUUIDType [16]byte + func TestUUIDSet(t *testing.T) { successfulTests := []struct { source interface{} @@ -32,6 +34,10 @@ func TestUUIDSet(t *testing.T) { source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, + { + source: SomeUUIDType{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, { source: ([]byte)(nil), result: pgtype.UUID{Status: pgtype.Null}, @@ -86,6 +92,21 @@ func TestUUIDAssignTo(t *testing.T) { } } + { + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst SomeUUIDType + expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + { src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} var dst string diff --git a/query.go b/query.go index 5c6cbf7f..bf4ec561 100644 --- a/query.go +++ b/query.go @@ -69,6 +69,25 @@ func (rows *Rows) Close() { return } + // If there is no error and a batch operation is not in progress read until we get the ReadyForQuery message or the + // ErrorResponse. This is necessary to detect a deferred constraint violation where the ErrorResponse is sent after + // CommandComplete. + if rows.err == nil && rows.batch == nil && rows.conn.pendingReadyForQueryCount == 1 { + for rows.conn.pendingReadyForQueryCount > 0 { + msg, err := rows.conn.rxMsg() + if err != nil { + rows.err = err + break + } + + err = rows.conn.processContextFreeMsg(msg) + if err != nil { + rows.err = err + break + } + } + } + if rows.unlockConn { rows.conn.unlock() rows.unlockConn = false diff --git a/query_test.go b/query_test.go index 06b7b8b7..ea1fd66e 100644 --- a/query_test.go +++ b/query_test.go @@ -14,7 +14,7 @@ import ( "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" satori "github.com/jackc/pgx/pgtype/ext/satori-uuid" - "github.com/satori/go.uuid" + uuid "github.com/satori/go.uuid" "github.com/shopspring/decimal" ) @@ -424,6 +424,47 @@ func TestConnQueryErrorWhileReturningRows(t *testing.T) { } +// https://github.com/jackc/pgx/issues/570 +func TestConnQueryDeferredError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred +); + +insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) + + rows, err := conn.Query(`update t set n=n+1 where id='b' returning *`) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + for rows.Next() { + var id string + var n int32 + err = rows.Scan(&id, &n) + if err != nil { + t.Fatal(err) + } + } + + if rows.Err() == nil { + t.Fatal("expected error 23505 but got none") + } + + if err, ok := rows.Err().(pgx.PgError); !ok || err.Code != "23505" { + t.Fatalf("expected error 23505, got %v", err) + } + + ensureConnValid(t, conn) +} + func TestQueryEncodeError(t *testing.T) { t.Parallel() diff --git a/stdlib/sql.go b/stdlib/sql.go index ec5933f3..3cd2d941 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -43,8 +43,8 @@ // // AcquireConn and ReleaseConn acquire and release a *pgx.Conn from the standard // database/sql.DB connection pool. This allows operations that must be -// performed on a single connection, but should not be run in a transaction or -// to use pgx specific functionality. +// performed on a single connection without running in a transaction, and it +// supports operations that use pgx specific functionality. // // conn, err := stdlib.AcquireConn(db) // if err != nil { @@ -277,7 +277,7 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e pgxOpts.IsoLevel = pgx.ReadUncommitted case sql.LevelReadCommitted: pgxOpts.IsoLevel = pgx.ReadCommitted - case sql.LevelSnapshot: + case sql.LevelRepeatableRead, sql.LevelSnapshot: pgxOpts.IsoLevel = pgx.RepeatableRead case sql.LevelSerializable: pgxOpts.IsoLevel = pgx.Serializable diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index cf2b91b1..895ee583 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -629,6 +629,7 @@ func TestConnBeginTxIsolation(t *testing.T) { {sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel}, {sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"}, {sqlIso: sql.LevelReadCommitted, pgIso: "read committed"}, + {sqlIso: sql.LevelRepeatableRead, pgIso: "repeatable read"}, {sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"}, {sqlIso: sql.LevelSerializable, pgIso: "serializable"}, } From 2d9d8dc52ac211c6191c08e050c03588aa633038 Mon Sep 17 00:00:00 2001 From: Joshua Barone Date: Thu, 12 Sep 2019 10:13:13 -0500 Subject: [PATCH 67/70] replace dsn parser with simple parser, rather than regex --- conn.go | 79 +++++++++++++++++++++++++++++++++++----------- conn_test.go | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 149 insertions(+), 18 deletions(-) diff --git a/conn.go b/conn.go index b4469af9..b67eb98a 100644 --- a/conn.go +++ b/conn.go @@ -17,7 +17,6 @@ import ( "os/user" "path/filepath" "reflect" - "regexp" "strconv" "strings" "sync" @@ -1062,7 +1061,7 @@ func ParseURI(uri string) (ConnConfig, error) { return cp, nil } -var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) +var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} // ParseDSN parses a database DSN (data source name) into a ConnConfig // @@ -1078,35 +1077,79 @@ var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) func ParseDSN(s string) (ConnConfig, error) { var cp ConnConfig - m := dsnRegexp.FindAllStringSubmatch(s, -1) - tlsArgs := configTLSArgs{} cp.RuntimeParams = make(map[string]string) var hostval, portval string - for _, b := range m { - switch b[1] { + for len(s) > 0 { + var key, val string + eqIdx := strings.IndexRune(s, '=') + if eqIdx < 0 { + return cp, errors.New("invalid dsn") + } + + key = strings.Trim(s[:eqIdx], " \t\n\r\v\f") + s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f") + if s[0] != '\'' { + end := 0 + for ; end < len(s); end++ { + if asciiSpace[s[end]] == 1 { + break + } + if s[end] == '\\' { + end++ + } + } + val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) + if end == len(s) { + s = "" + } else { + s = s[end+1:] + } + } else { // quoted string + s = s[1:] + end := 0 + for ; end < len(s); end++ { + if s[end] == '\'' { + break + } + if s[end] == '\\' { + end++ + } + } + if end == len(s) { + return cp, errors.New("unterminated quoted string in connection info string") + } + val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) + if end == len(s) { + s = "" + } else { + s = s[end+1:] + } + } + + switch key { case "user": - cp.User = b[2] + cp.User = val case "password": - cp.Password = b[2] + cp.Password = val case "host": - hostval = b[2] + hostval = val case "port": - portval = b[2] + portval = val case "dbname": - cp.Database = b[2] + cp.Database = val case "sslmode": - tlsArgs.sslMode = b[2] + tlsArgs.sslMode = val case "sslrootcert": - tlsArgs.sslRootCert = b[2] + tlsArgs.sslRootCert = val case "sslcert": - tlsArgs.sslCert = b[2] + tlsArgs.sslCert = val case "sslkey": - tlsArgs.sslKey = b[2] + tlsArgs.sslKey = val case "connect_timeout": - timeout, err := strconv.ParseInt(b[2], 10, 64) + timeout, err := strconv.ParseInt(val, 10, 64) if err != nil { return cp, err } @@ -1114,12 +1157,12 @@ func ParseDSN(s string) (ConnConfig, error) { d.Timeout = time.Duration(timeout) * time.Second cp.Dial = d.Dial case "target_session_attrs": - cp.TargetSessionAttrs = TargetSessionType(b[2]) + cp.TargetSessionAttrs = TargetSessionType(val) if err := cp.TargetSessionAttrs.isValid(); err != nil { return cp, err } default: - cp.RuntimeParams[b[1]] = b[2] + cp.RuntimeParams[key] = val } } diff --git a/conn_test.go b/conn_test.go index c6ce50cc..42e9c00b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -717,6 +717,38 @@ func TestParseDSN(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + url: "user=jack\\'s password=secret host=localhost port=5432 dbname=mydb", + connParams: pgx.ConnConfig{ + User: "jack's", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack password=sooper\\\\secret host=localhost port=5432 dbname=mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "sooper\\secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, { url: "user=jack host=localhost port=5432 dbname=mydb", connParams: pgx.ConnConfig{ @@ -822,6 +854,62 @@ func TestParseDSN(t *testing.T) { TargetSessionAttrs: pgx.ReadWriteTargetSession, }, }, + { + url: "user='jack' host='localhost' dbname='mydb'", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user='jack\\'s' host='localhost' dbname='mydb'", + connParams: pgx.ConnConfig{ + User: "jack's", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user='jack' password='' host='localhost' dbname='mydb'", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user = 'jack' password = '' host = 'localhost' dbname = 'mydb'", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, } for i, tt := range tests { From d52bd74254a9b2a19ebac3b10ca590e99f58d6af Mon Sep 17 00:00:00 2001 From: David Hudson Date: Fri, 13 Sep 2019 16:37:38 +0100 Subject: [PATCH 68/70] pgtype: Add ext type for gofrs uuid implementation Add ext type for https://github.com/gofrs/uuid uuid type. Change test and README from github.com/satori/go.uuid to github.com/gofrs/uuid. The reason is due to this issue: https://github.com/satori/go.uuid/issues/73. This was taken on board and fixed in the community project of gofrs. The gofrs implementation has the same interface as the original. --- README.md | 2 +- pgtype/ext/gofrs-uuid/uuid.go | 161 +++++++++++++++++++++++++++++ pgtype/ext/gofrs-uuid/uuid_test.go | 97 +++++++++++++++++ query_test.go | 6 +- travis/install.bash | 1 + 5 files changed, 263 insertions(+), 4 deletions(-) create mode 100644 pgtype/ext/gofrs-uuid/uuid.go create mode 100644 pgtype/ext/gofrs-uuid/uuid_test.go diff --git a/README.md b/README.md index 0a4cacc3..1c466c11 100644 --- a/README.md +++ b/README.md @@ -85,11 +85,11 @@ skip tests for connection types that are not configured. To setup the normal test environment, first install these dependencies: go get github.com/cockroachdb/apd + go get github.com/gofrs/uuid go get github.com/hashicorp/go-version go get github.com/jackc/fake go get github.com/lib/pq go get github.com/pkg/errors - go get github.com/satori/go.uuid go get github.com/shopspring/decimal go get github.com/sirupsen/logrus go get go.uber.org/zap diff --git a/pgtype/ext/gofrs-uuid/uuid.go b/pgtype/ext/gofrs-uuid/uuid.go new file mode 100644 index 00000000..e859f6ef --- /dev/null +++ b/pgtype/ext/gofrs-uuid/uuid.go @@ -0,0 +1,161 @@ +package uuid + +import ( + "database/sql/driver" + + "github.com/gofrs/uuid" + "github.com/pkg/errors" + + "github.com/jackc/pgx/pgtype" +) + +var errUndefined = errors.New("cannot encode status undefined") + +type UUID struct { + UUID uuid.UUID + Status pgtype.Status +} + +func (dst *UUID) Set(src interface{}) error { + switch value := src.(type) { + case uuid.UUID: + *dst = UUID{UUID: value, Status: pgtype.Present} + case [16]byte: + *dst = UUID{UUID: uuid.UUID(value), Status: pgtype.Present} + case []byte: + if len(value) != 16 { + return errors.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) + } + *dst = UUID{Status: pgtype.Present} + copy(dst.UUID[:], value) + case string: + uuid, err := uuid.FromString(value) + if err != nil { + return err + } + *dst = UUID{UUID: uuid, Status: pgtype.Present} + default: + // If all else fails see if pgtype.UUID can handle it. If so, translate through that. + pgUUID := &pgtype.UUID{} + if err := pgUUID.Set(value); err != nil { + return errors.Errorf("cannot convert %v to UUID", value) + } + + *dst = UUID{UUID: uuid.UUID(pgUUID.Bytes), Status: pgUUID.Status} + } + + return nil +} + +func (dst *UUID) Get() interface{} { + switch dst.Status { + case pgtype.Present: + return dst.UUID + case pgtype.Null: + return nil + default: + return dst.Status + } +} + +func (src *UUID) AssignTo(dst interface{}) error { + switch src.Status { + case pgtype.Present: + switch v := dst.(type) { + case *uuid.UUID: + *v = src.UUID + case *[16]byte: + *v = [16]byte(src.UUID) + return nil + case *[]byte: + *v = make([]byte, 16) + copy(*v, src.UUID[:]) + return nil + case *string: + *v = src.UUID.String() + return nil + default: + if nextDst, retry := pgtype.GetAssignToDstType(v); retry { + return src.AssignTo(nextDst) + } + } + case pgtype.Null: + return pgtype.NullAssignTo(dst) + } + + return errors.Errorf("cannot assign %v into %T", src, dst) +} + +func (dst *UUID) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + *dst = UUID{Status: pgtype.Null} + return nil + } + + u, err := uuid.FromString(string(src)) + if err != nil { + return err + } + + *dst = UUID{UUID: u, Status: pgtype.Present} + return nil +} + +func (dst *UUID) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + *dst = UUID{Status: pgtype.Null} + return nil + } + + if len(src) != 16 { + return errors.Errorf("invalid length for UUID: %v", len(src)) + } + + *dst = UUID{Status: pgtype.Present} + copy(dst.UUID[:], src) + return nil +} + +func (src *UUID) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case pgtype.Null: + return nil, nil + case pgtype.Undefined: + return nil, errUndefined + } + + return append(buf, src.UUID.String()...), nil +} + +func (src *UUID) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case pgtype.Null: + return nil, nil + case pgtype.Undefined: + return nil, errUndefined + } + + return append(buf, src.UUID[:]...), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *UUID) Scan(src interface{}) error { + if src == nil { + *dst = UUID{Status: pgtype.Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *UUID) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/pgtype/ext/gofrs-uuid/uuid_test.go b/pgtype/ext/gofrs-uuid/uuid_test.go new file mode 100644 index 00000000..d76edb18 --- /dev/null +++ b/pgtype/ext/gofrs-uuid/uuid_test.go @@ -0,0 +1,97 @@ +package uuid_test + +import ( + "bytes" + "testing" + + "github.com/jackc/pgx/pgtype" + gofrs "github.com/jackc/pgx/pgtype/ext/gofrs-uuid" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestUUIDTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ + &gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + &gofrs.UUID{Status: pgtype.Null}, + }) +} + +func TestUUIDSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result gofrs.UUID + }{ + { + source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: "00010203-0405-0607-0809-0a0b0c0d0e0f", + result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r gofrs.UUID + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestUUIDAssignTo(t *testing.T) { + { + src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst [16]byte + expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst []byte + expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if bytes.Compare(dst, expected) != 0 { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst string + expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + +} diff --git a/query_test.go b/query_test.go index ea1fd66e..500399b9 100644 --- a/query_test.go +++ b/query_test.go @@ -11,10 +11,10 @@ import ( "time" "github.com/cockroachdb/apd" + "github.com/gofrs/uuid" "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" - satori "github.com/jackc/pgx/pgtype/ext/satori-uuid" - uuid "github.com/satori/go.uuid" + gofrs "github.com/jackc/pgx/pgtype/ext/gofrs-uuid" "github.com/shopspring/decimal" ) @@ -1140,7 +1140,7 @@ func TestConnQueryDatabaseSQLDriverValuerWithBinaryPgTypeThatAcceptsSameType(t * defer closeConn(t, conn) conn.ConnInfo.RegisterDataType(pgtype.DataType{ - Value: &satori.UUID{}, + Value: &gofrs.UUID{}, Name: "uuid", OID: 2950, }) diff --git a/travis/install.bash b/travis/install.bash index 3c3e44cf..c3b344e5 100755 --- a/travis/install.bash +++ b/travis/install.bash @@ -4,6 +4,7 @@ set -eux go get -u github.com/cockroachdb/apd go get -u github.com/shopspring/decimal go get -u gopkg.in/inconshreveable/log15.v2 +go get -u github.com/gofrs/uuid go get -u github.com/jackc/fake go get -u github.com/lib/pq go get -u github.com/hashicorp/go-version From 6e0acb04d36cd02f92a372399c12e89be96a3f00 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 14 Sep 2019 18:55:06 -0500 Subject: [PATCH 69/70] Fix flickering test --- FAIL: TestPoolWithAcquireExContextTimeoutSet (2.03s) conn_pool_test.go:353: Expected connection allocation time to be at least 2s, instead it was '1.999691391s' These failures were caused by setting the timeout and then measuring how long an acquire took. --- conn_pool_test.go | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/conn_pool_test.go b/conn_pool_test.go index 83bdf1fd..db645e63 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -39,18 +39,6 @@ func releaseAllConnections(pool *pgx.ConnPool, connections []*pgx.Conn) { } } -func acquireWithTimeTaken(pool *pgx.ConnPool) (*pgx.Conn, time.Duration, error) { - startTime := time.Now() - c, err := pool.Acquire() - return c, time.Since(startTime), err -} - -func acquireExWithTimeTaken(pool *pgx.ConnPool, ctx context.Context) (*pgx.Conn, time.Duration, error) { - startTime := time.Now() - c, err := pool.AcquireEx(ctx) - return c, time.Since(startTime), err -} - func TestNewConnPool(t *testing.T) { t.Parallel() @@ -282,11 +270,14 @@ func TestPoolWithAcquireTimeoutSet(t *testing.T) { defer releaseAllConnections(pool, allConnections) // ... then try to consume 1 more. It should fail after a short timeout. - _, timeTaken, err := acquireWithTimeTaken(pool) + startTime := time.Now() + _, err = pool.Acquire() if err == nil || err != pgx.ErrAcquireTimeout { t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) } + + timeTaken := time.Now().Sub(startTime) if timeTaken < connAllocTimeout { t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", connAllocTimeout, timeTaken) } @@ -302,6 +293,8 @@ func TestPoolWithoutAcquireTimeoutSet(t *testing.T) { // Consume all connections ... allConnections := acquireAllConnections(t, pool, maxConnections) + startTime := time.Now() + // ... then try to consume 1 more. It should hang forever. // To unblock it we release the previously taken connection in a goroutine. stopDeadWaitTimeout := 5 * time.Second @@ -310,12 +303,13 @@ func TestPoolWithoutAcquireTimeoutSet(t *testing.T) { }) defer timer.Stop() - conn, timeTaken, err := acquireWithTimeTaken(pool) + conn, err := pool.Acquire() if err == nil { pool.Release(conn) } else { t.Fatalf("Expected error to be nil, instead it was '%v'", err) } + timeTaken := time.Now().Sub(startTime) if timeTaken < stopDeadWaitTimeout { t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", stopDeadWaitTimeout, timeTaken) } @@ -339,16 +333,18 @@ func TestPoolWithAcquireExContextTimeoutSet(t *testing.T) { allConnections := acquireAllConnections(t, pool, config.MaxConnections) defer releaseAllConnections(pool, allConnections) + startTime := time.Now() ctxTimeout := 2 * time.Second ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) defer cancel() // ... then try to consume 1 more. It should fail after a short timeout. - _, timeTaken, err := acquireExWithTimeTaken(pool, ctx) + _, err = pool.AcquireEx(ctx) if err == nil || err != pgx.ErrAcquireTimeout { t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) } + timeTaken := time.Now().Sub(startTime) if timeTaken < ctxTimeout { t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", ctxTimeout, timeTaken) } @@ -374,16 +370,18 @@ func TestPoolWithAcquireExPoolTimeoutLower(t *testing.T) { allConnections := acquireAllConnections(t, pool, config.MaxConnections) defer releaseAllConnections(pool, allConnections) + startTime := time.Now() ctxTimeout := 5 * time.Second ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) defer cancel() // ... then try to consume 1 more. It should fail after a short timeout. - _, timeTaken, err := acquireExWithTimeTaken(pool, ctx) + _, err = pool.AcquireEx(ctx) if err == nil || err != pgx.ErrAcquireTimeout { t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) } + timeTaken := time.Now().Sub(startTime) if timeTaken < connAllocTimeout { t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", connAllocTimeout, timeTaken) } @@ -412,16 +410,18 @@ func TestPoolWithAcquireExPoolTimeoutHigher(t *testing.T) { allConnections := acquireAllConnections(t, pool, config.MaxConnections) defer releaseAllConnections(pool, allConnections) + startTime := time.Now() ctxTimeout := 2 * time.Second ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) defer cancel() // ... then try to consume 1 more. It should fail after a short timeout. - _, timeTaken, err := acquireExWithTimeTaken(pool, ctx) + _, err = pool.AcquireEx(ctx) if err == nil || err != pgx.ErrAcquireTimeout { t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) } + timeTaken := time.Now().Sub(startTime) if timeTaken < ctxTimeout { t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", ctxTimeout, timeTaken) } @@ -442,18 +442,20 @@ func TestPoolWithoutAcquireExTimeoutSet(t *testing.T) { // ... then try to consume 1 more. It should hang forever. // To unblock it we release the previously taken connection in a goroutine. + startTime := time.Now() stopDeadWaitTimeout := 5 * time.Second timer := time.AfterFunc(stopDeadWaitTimeout+100*time.Millisecond, func() { releaseAllConnections(pool, allConnections) }) defer timer.Stop() - conn, timeTaken, err := acquireExWithTimeTaken(pool, context.Background()) + conn, err := pool.AcquireEx(context.Background()) if err == nil { pool.Release(conn) } else { t.Fatalf("Expected error to be nil, instead it was '%v'", err) } + timeTaken := time.Now().Sub(startTime) if timeTaken < stopDeadWaitTimeout { t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", stopDeadWaitTimeout, timeTaken) } From c73e7d75061bb42b0282945710f344cfe1113d10 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 14 Sep 2019 20:14:24 -0500 Subject: [PATCH 70/70] Release v3.6.0 --- CHANGELOG.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f22d8d29..8b626330 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,29 @@ +# 3.6.0 (September 14, 2019) + +## Features + +* Improve zap logger (Nicholas Wilson) +* Improve composite type support (David) +* Connect multi-host support (Artemiy Ryabinkov) +* Improve UUID to any [16]byte conversion (Nicholas Wilson) +* Add database/sql repeatable read transaction support (Nathaniel Caza) +* Load user-defined array type oids (Dmitriy Garanzha) +* Add ConnPool.AcquireEx (Kale Blankenship) + +## Fixes + +* Remove 0 bytes when sanitizing identifiers +* Terminate context prior to releasing when killing batch connection (Andrew S. Brown) +* Do not ignore PostgreSQL errors from deferred constraints +* Correct example for WaitForNotification (Ian Stapleton Cordasco) +* Include ParameterOIDs when preparing statements on new pool connections (Kale Blankenship) +* Fix DSN parsing with single quoted values (Joshua Barone) + +## Changes + +* Adjust default read buffer to match default PostgreSQL send buffer (Artemiy Ryabinkov) +* Add https://github.com/gofrs/uuid extension (David Hudson) + # 3.5.0 (June 29, 2019) ## Features