From 0c35c9e630348731e17a6322a935a00b22eb9c19 Mon Sep 17 00:00:00 2001 From: Mitar Date: Sun, 14 Jan 2024 00:46:00 +0100 Subject: [PATCH 01/38] Revert "Document max read and write sizes for large objects" This reverts commit b99e2bb7e0818428092e955cb0ee9cff45504bfd. --- large_objects.go | 8 -------- 1 file changed, 8 deletions(-) diff --git a/large_objects.go b/large_objects.go index 67666745..c238ab9c 100644 --- a/large_objects.go +++ b/large_objects.go @@ -67,10 +67,6 @@ type LargeObject struct { } // Write writes p to the large object and returns the number of bytes written and an error if not all of p was written. -// -// Write is implemented with a single call to lowrite. The PostgreSQL wire protocol has a limit of 1 GB - 1 per message. -// See definition of PQ_LARGE_MESSAGE_LIMIT in the PostgreSQL source code. To allow for the other data in the message, -// len(p) should be no larger than 1 GB - 1 KB. func (o *LargeObject) Write(p []byte) (int, error) { var n int err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p).Scan(&n) @@ -86,10 +82,6 @@ func (o *LargeObject) Write(p []byte) (int, error) { } // Read reads up to len(p) bytes into p returning the number of bytes read. -// -// Read is implemented with a single call to loread. PostgreSQL internally allocates a single buffer for the response. -// The largest buffer PostgreSQL will allocate is 1 GB - 1. See definition of MaxAllocSize in the PostgreSQL source -// code. To allow for the other data in the message, len(p) should be no larger than 1 GB - 1 KB. func (o *LargeObject) Read(p []byte) (int, error) { var res []byte err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, len(p)).Scan(&res) From a4ca0917da3ce5de1dc031680cb158a86f200487 Mon Sep 17 00:00:00 2001 From: Mitar Date: Sun, 14 Jan 2024 02:08:27 +0100 Subject: [PATCH 02/38] Support large large objects. Fixes #1865. --- large_objects.go | 73 ++++++++++++++++++++++++++--------- large_objects_private_test.go | 20 ++++++++++ large_objects_test.go | 9 +++-- 3 files changed, 81 insertions(+), 21 deletions(-) create mode 100644 large_objects_private_test.go diff --git a/large_objects.go b/large_objects.go index c238ab9c..a3028b63 100644 --- a/large_objects.go +++ b/large_objects.go @@ -6,6 +6,11 @@ import ( "io" ) +// The PostgreSQL wire protocol has a limit of 1 GB - 1 per message. See definition of +// PQ_LARGE_MESSAGE_LIMIT in the PostgreSQL source code. To allow for the other data +// in the message,maxLargeObjectMessageLength should be no larger than 1 GB - 1 KB. +var maxLargeObjectMessageLength = 1024*1024*1024 - 1024 + // LargeObjects is a structure used to access the large objects API. It is only valid within the transaction where it // was created. // @@ -68,32 +73,64 @@ type LargeObject struct { // Write writes p to the large object and returns the number of bytes written and an error if not all of p was written. func (o *LargeObject) Write(p []byte) (int, error) { - var n int - err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p).Scan(&n) - if err != nil { - return n, err + nTotal := 0 + for { + expected := len(p) - nTotal + if expected == 0 { + break + } else if expected > maxLargeObjectMessageLength { + expected = maxLargeObjectMessageLength + } + + var n int + err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p[nTotal:nTotal+expected]).Scan(&n) + if err != nil { + return nTotal, err + } + + if n < 0 { + return nTotal, errors.New("failed to write to large object") + } + + nTotal += n + + if n < expected { + return nTotal, errors.New("short write to large object") + } else if n > expected { + return nTotal, errors.New("invalid write to large object") + } } - if n < 0 { - return 0, errors.New("failed to write to large object") - } - - return n, nil + return nTotal, nil } // Read reads up to len(p) bytes into p returning the number of bytes read. func (o *LargeObject) Read(p []byte) (int, error) { - var res []byte - err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, len(p)).Scan(&res) - copy(p, res) - if err != nil { - return len(res), err + nTotal := 0 + for { + expected := len(p) - nTotal + if expected == 0 { + break + } else if expected > maxLargeObjectMessageLength { + expected = maxLargeObjectMessageLength + } + + var res []byte + err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, expected).Scan(&res) + copy(p[nTotal:], res) + nTotal += len(res) + if err != nil { + return nTotal, err + } + + if len(res) < expected { + return nTotal, io.EOF + } else if len(res) > expected { + return nTotal, errors.New("invalid read of large object") + } } - if len(res) < len(p) { - err = io.EOF - } - return len(res), err + return nTotal, nil } // Seek moves the current location pointer to the new location specified by offset. diff --git a/large_objects_private_test.go b/large_objects_private_test.go new file mode 100644 index 00000000..36eca8f0 --- /dev/null +++ b/large_objects_private_test.go @@ -0,0 +1,20 @@ +package pgx + +import ( + "testing" +) + +// SetMaxLargeObjectMessageLength sets internal maxLargeObjectMessageLength variable +// to the given length for the duration of the test. +// +// Tests using this helper should not use t.Parallel(). +func SetMaxLargeObjectMessageLength(t *testing.T, length int) { + t.Helper() + + original := maxLargeObjectMessageLength + t.Cleanup(func() { + maxLargeObjectMessageLength = original + }) + + maxLargeObjectMessageLength = length +} diff --git a/large_objects_test.go b/large_objects_test.go index 25611bf6..de2eed0d 100644 --- a/large_objects_test.go +++ b/large_objects_test.go @@ -13,7 +13,8 @@ import ( ) func TestLargeObjects(t *testing.T) { - t.Parallel() + // We use a very short limit to test chunking logic. + pgx.SetMaxLargeObjectMessageLength(t, 2) ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() @@ -34,7 +35,8 @@ func TestLargeObjects(t *testing.T) { } func TestLargeObjectsSimpleProtocol(t *testing.T) { - t.Parallel() + // We use a very short limit to test chunking logic. + pgx.SetMaxLargeObjectMessageLength(t, 2) ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() @@ -160,7 +162,8 @@ func testLargeObjects(t *testing.T, ctx context.Context, tx pgx.Tx) { } func TestLargeObjectsMultipleTransactions(t *testing.T) { - t.Parallel() + // We use a very short limit to test chunking logic. + pgx.SetMaxLargeObjectMessageLength(t, 2) ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() From 517c654e2cfa7f953bf3e7ac44fe2260cfde669a Mon Sep 17 00:00:00 2001 From: Kirill Malikov Date: Sat, 20 Jan 2024 13:49:26 +0300 Subject: [PATCH 03/38] feat: fast encodeUUID --- pgtype/uuid.go | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/pgtype/uuid.go b/pgtype/uuid.go index b59d6e76..d57c0f2f 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -52,7 +52,19 @@ func parseUUID(src string) (dst [16]byte, err error) { // encodeUUID converts a uuid byte array to UUID standard string form. func encodeUUID(src [16]byte) string { - return fmt.Sprintf("%x-%x-%x-%x-%x", src[0:4], src[4:6], src[6:8], src[8:10], src[10:16]) + var buf [36]byte + + hex.Encode(buf[0:8], src[:4]) + buf[8] = '-' + hex.Encode(buf[9:13], src[4:6]) + buf[13] = '-' + hex.Encode(buf[14:18], src[6:8]) + buf[18] = '-' + hex.Encode(buf[19:23], src[8:10]) + buf[23] = '-' + hex.Encode(buf[24:], src[10:]) + + return string(buf[:]) } // Scan implements the database/sql Scanner interface. From a57bb8caeab831f71b25d77c9dde5d430226cf49 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Tue, 23 Jan 2024 16:51:15 +0100 Subject: [PATCH 04/38] Add `AppendRows` helper --- rows.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/rows.go b/rows.go index 1ad91765..17e36cba 100644 --- a/rows.go +++ b/rows.go @@ -417,12 +417,10 @@ type CollectableRow interface { // RowToFunc is a function that scans or otherwise converts row to a T. type RowToFunc[T any] func(row CollectableRow) (T, error) -// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T. -func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) { +// AppendRows iterates through rows, calling fn for each row, and appending the results into a slice of T. +func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) { defer rows.Close() - slice := []T{} - for rows.Next() { value, err := fn(rows) if err != nil { @@ -438,6 +436,11 @@ func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) { return slice, nil } +// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T. +func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) { + return AppendRows([]T(nil), rows, fn) +} + // CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true. // CollectOneRow is to CollectRows as QueryRow is to Query. func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { From c90f82a4e3b601ffdd9fdff872539a96987b250e Mon Sep 17 00:00:00 2001 From: Pavlo Golub Date: Thu, 25 Jan 2024 11:39:17 +0100 Subject: [PATCH 05/38] make properties of QueuedQuery and Batch public, closes #1878 --- batch.go | 42 +++++++++++++++++++++--------------------- conn.go | 56 ++++++++++++++++++++++++++++---------------------------- 2 files changed, 49 insertions(+), 49 deletions(-) diff --git a/batch.go b/batch.go index 9b943621..b9b46d1d 100644 --- a/batch.go +++ b/batch.go @@ -10,8 +10,8 @@ import ( // QueuedQuery is a query that has been queued for execution via a Batch. type QueuedQuery struct { - query string - arguments []any + SQL string + Arguments []any fn batchItemFunc sd *pgconn.StatementDescription } @@ -57,7 +57,7 @@ func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) { // Batch queries are a way of bundling multiple queries together to avoid // unnecessary network round trips. A Batch must only be sent once. type Batch struct { - queuedQueries []*QueuedQuery + QueuedQueries []*QueuedQuery } // Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. @@ -65,16 +65,16 @@ type Batch struct { // connection's DefaultQueryExecMode. func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery { qq := &QueuedQuery{ - query: query, - arguments: arguments, + SQL: query, + Arguments: arguments, } - b.queuedQueries = append(b.queuedQueries, qq) + b.QueuedQueries = append(b.QueuedQueries, qq) return qq } // Len returns number of queries that have been queued so far. func (b *Batch) Len() int { - return len(b.queuedQueries) + return len(b.QueuedQueries) } type BatchResults interface { @@ -227,9 +227,9 @@ func (br *batchResults) Close() error { } // Read and run fn for all remaining items - for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) { - if br.b.queuedQueries[br.qqIdx].fn != nil { - err := br.b.queuedQueries[br.qqIdx].fn(br) + for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) { + if br.b.QueuedQueries[br.qqIdx].fn != nil { + err := br.b.QueuedQueries[br.qqIdx].fn(br) if err != nil { br.err = err } @@ -253,10 +253,10 @@ func (br *batchResults) earlyError() error { } func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) { - if br.b != nil && br.qqIdx < len(br.b.queuedQueries) { - bi := br.b.queuedQueries[br.qqIdx] - query = bi.query - args = bi.arguments + if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) { + bi := br.b.QueuedQueries[br.qqIdx] + query = bi.SQL + args = bi.Arguments ok = true br.qqIdx++ } @@ -396,9 +396,9 @@ func (br *pipelineBatchResults) Close() error { } // Read and run fn for all remaining items - for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) { - if br.b.queuedQueries[br.qqIdx].fn != nil { - err := br.b.queuedQueries[br.qqIdx].fn(br) + for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) { + if br.b.QueuedQueries[br.qqIdx].fn != nil { + err := br.b.QueuedQueries[br.qqIdx].fn(br) if err != nil { br.err = err } @@ -422,10 +422,10 @@ func (br *pipelineBatchResults) earlyError() error { } func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) { - if br.b != nil && br.qqIdx < len(br.b.queuedQueries) { - bi := br.b.queuedQueries[br.qqIdx] - query = bi.query - args = bi.arguments + if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) { + bi := br.b.QueuedQueries[br.qqIdx] + query = bi.SQL + args = bi.Arguments ok = true br.qqIdx++ } diff --git a/conn.go b/conn.go index 64ae48ca..deb0f48c 100644 --- a/conn.go +++ b/conn.go @@ -903,10 +903,10 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { return &batchResults{ctx: ctx, conn: c, err: err} } - for _, bi := range b.queuedQueries { + for _, bi := range b.QueuedQueries { var queryRewriter QueryRewriter - sql := bi.query - arguments := bi.arguments + sql := bi.SQL + arguments := bi.Arguments optionLoop: for len(arguments) > 0 { @@ -928,8 +928,8 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { } } - bi.query = sql - bi.arguments = arguments + bi.SQL = sql + bi.Arguments = arguments } // TODO: changing mode per batch? Update Batch.Queue function comment when implemented @@ -939,8 +939,8 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { } // All other modes use extended protocol and thus can use prepared statements. - for _, bi := range b.queuedQueries { - if sd, ok := c.preparedStatements[bi.query]; ok { + for _, bi := range b.QueuedQueries { + if sd, ok := c.preparedStatements[bi.SQL]; ok { bi.sd = sd } } @@ -961,11 +961,11 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults { var sb strings.Builder - for i, bi := range b.queuedQueries { + for i, bi := range b.QueuedQueries { if i > 0 { sb.WriteByte(';') } - sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...) + sql, err := c.sanitizeForSimpleQuery(bi.SQL, bi.Arguments...) if err != nil { return &batchResults{ctx: ctx, conn: c, err: err} } @@ -984,21 +984,21 @@ func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batc func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults { batch := &pgconn.Batch{} - for _, bi := range b.queuedQueries { + for _, bi := range b.QueuedQueries { sd := bi.sd if sd != nil { - err := c.eqb.Build(c.typeMap, sd, bi.arguments) + err := c.eqb.Build(c.typeMap, sd, bi.Arguments) if err != nil { return &batchResults{ctx: ctx, conn: c, err: err} } batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) } else { - err := c.eqb.Build(c.typeMap, nil, bi.arguments) + err := c.eqb.Build(c.typeMap, nil, bi.Arguments) if err != nil { return &batchResults{ctx: ctx, conn: c, err: err} } - batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats) + batch.ExecParams(bi.SQL, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats) } } @@ -1023,18 +1023,18 @@ func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batc distinctNewQueries := []*pgconn.StatementDescription{} distinctNewQueriesIdxMap := make(map[string]int) - for _, bi := range b.queuedQueries { + for _, bi := range b.QueuedQueries { if bi.sd == nil { - sd := c.statementCache.Get(bi.query) + sd := c.statementCache.Get(bi.SQL) if sd != nil { bi.sd = sd } else { - if idx, present := distinctNewQueriesIdxMap[bi.query]; present { + if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present { bi.sd = distinctNewQueries[idx] } else { sd = &pgconn.StatementDescription{ - Name: stmtcache.StatementName(bi.query), - SQL: bi.query, + Name: stmtcache.StatementName(bi.SQL), + SQL: bi.SQL, } distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) distinctNewQueries = append(distinctNewQueries, sd) @@ -1055,17 +1055,17 @@ func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch distinctNewQueries := []*pgconn.StatementDescription{} distinctNewQueriesIdxMap := make(map[string]int) - for _, bi := range b.queuedQueries { + for _, bi := range b.QueuedQueries { if bi.sd == nil { - sd := c.descriptionCache.Get(bi.query) + sd := c.descriptionCache.Get(bi.SQL) if sd != nil { bi.sd = sd } else { - if idx, present := distinctNewQueriesIdxMap[bi.query]; present { + if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present { bi.sd = distinctNewQueries[idx] } else { sd = &pgconn.StatementDescription{ - SQL: bi.query, + SQL: bi.SQL, } distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) distinctNewQueries = append(distinctNewQueries, sd) @@ -1082,13 +1082,13 @@ func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) distinctNewQueries := []*pgconn.StatementDescription{} distinctNewQueriesIdxMap := make(map[string]int) - for _, bi := range b.queuedQueries { + for _, bi := range b.QueuedQueries { if bi.sd == nil { - if idx, present := distinctNewQueriesIdxMap[bi.query]; present { + if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present { bi.sd = distinctNewQueries[idx] } else { sd := &pgconn.StatementDescription{ - SQL: bi.query, + SQL: bi.SQL, } distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) distinctNewQueries = append(distinctNewQueries, sd) @@ -1154,11 +1154,11 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d } // Queue the queries. - for _, bi := range b.queuedQueries { - err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments) + for _, bi := range b.QueuedQueries { + err := c.eqb.Build(c.typeMap, bi.sd, bi.Arguments) if err != nil { // we wrap the error so we the user can understand which query failed inside the batch - err = fmt.Errorf("error building query %s: %w", bi.query, err) + err = fmt.Errorf("error building query %s: %w", bi.SQL, err) return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } From 0fa533386c98bd9ee1b492114d34f78b236ec2b7 Mon Sep 17 00:00:00 2001 From: Florent Viel Date: Tue, 23 Jan 2024 17:41:56 +0100 Subject: [PATCH 06/38] add ltree pgtype support --- pgtype/ltree.go | 122 +++++++++++++++++++++++++++++++++++++++++++ pgtype/ltree_test.go | 26 +++++++++ 2 files changed, 148 insertions(+) create mode 100644 pgtype/ltree.go create mode 100644 pgtype/ltree_test.go diff --git a/pgtype/ltree.go b/pgtype/ltree.go new file mode 100644 index 00000000..6af31779 --- /dev/null +++ b/pgtype/ltree.go @@ -0,0 +1,122 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" +) + +type LtreeCodec struct{} + +func (l LtreeCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +// PreferredFormat returns the preferred format. +func (l LtreeCodec) PreferredFormat() int16 { + return TextFormatCode +} + +// PlanEncode returns an EncodePlan for encoding value into PostgreSQL format for oid and format. If no plan can be +// found then nil is returned. +func (l LtreeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case TextFormatCode: + return (TextCodec)(l).PlanEncode(m, oid, format, value) + case BinaryFormatCode: + switch value.(type) { + case string: + return encodeLtreeCodecBinaryString{} + case []byte: + return encodeLtreeCodecBinaryByteSlice{} + case TextValuer: + return encodeLtreeCodecBinaryTextValuer{} + } + } + + return nil +} + +type encodeLtreeCodecBinaryString struct{} + +func (encodeLtreeCodecBinaryString) Encode(value any, buf []byte) (newBuf []byte, err error) { + ltree := value.(string) + buf = append(buf, 1) + return append(buf, ltree...), nil +} + +type encodeLtreeCodecBinaryByteSlice struct{} + +func (encodeLtreeCodecBinaryByteSlice) Encode(value any, buf []byte) (newBuf []byte, err error) { + ltree := value.([]byte) + buf = append(buf, 1) + return append(buf, ltree...), nil +} + +type encodeLtreeCodecBinaryTextValuer struct{} + +func (encodeLtreeCodecBinaryTextValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + t, err := value.(TextValuer).TextValue() + if err != nil { + return nil, err + } + if !t.Valid { + return nil, nil + } + + buf = append(buf, 1) + return append(buf, t.String...), nil +} + +// PlanScan returns a ScanPlan for scanning a PostgreSQL value into a destination with the same type as target. If +// no plan can be found then nil is returned. +func (l LtreeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case TextFormatCode: + return (TextCodec)(l).PlanScan(m, oid, format, target) + case BinaryFormatCode: + switch target.(type) { + case *string: + return scanPlanBinaryLtreeToString{} + case TextScanner: + return scanPlanBinaryLtreeToTextScanner{} + } + } + + return nil +} + +type scanPlanBinaryLtreeToString struct{} + +func (scanPlanBinaryLtreeToString) Scan(src []byte, target any) error { + version := src[0] + if version != 1 { + return fmt.Errorf("unsupported ltree version %d", version) + } + + p := (target).(*string) + *p = string(src[1:]) + + return nil +} + +type scanPlanBinaryLtreeToTextScanner struct{} + +func (scanPlanBinaryLtreeToTextScanner) Scan(src []byte, target any) error { + version := src[0] + if version != 1 { + return fmt.Errorf("unsupported ltree version %d", version) + } + + scanner := (target).(TextScanner) + return scanner.ScanText(Text{String: string(src[1:]), Valid: true}) +} + +// DecodeDatabaseSQLValue returns src decoded into a value compatible with the sql.Scanner interface. +func (l LtreeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return (TextCodec)(l).DecodeDatabaseSQLValue(m, oid, format, src) +} + +// DecodeValue returns src decoded into its default format. +func (l LtreeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + return (TextCodec)(l).DecodeValue(m, oid, format, src) +} diff --git a/pgtype/ltree_test.go b/pgtype/ltree_test.go new file mode 100644 index 00000000..2ec850f5 --- /dev/null +++ b/pgtype/ltree_test.go @@ -0,0 +1,26 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func TestLtreeCodec(t *testing.T) { + skipCockroachDB(t, "Server does not support type ltree") + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "ltree", []pgxtest.ValueRoundTripTest{ + { + Param: "A.B.C", + Result: new(string), + Test: isExpectedEq("A.B.C"), + }, + { + Param: pgtype.Text{String: "", Valid: true}, + Result: new(pgtype.Text), + Test: isExpectedEq(pgtype.Text{String: "", Valid: true}), + }, + }) +} From bf1c1d7848c087c99fb30a67657a5ec3197f7f32 Mon Sep 17 00:00:00 2001 From: Florent Viel Date: Wed, 24 Jan 2024 09:21:24 +0100 Subject: [PATCH 07/38] create ltree extension in pg setup for tests --- testsetup/postgresql_setup.sql | 1 + 1 file changed, 1 insertion(+) diff --git a/testsetup/postgresql_setup.sql b/testsetup/postgresql_setup.sql index 51414863..837c978a 100644 --- a/testsetup/postgresql_setup.sql +++ b/testsetup/postgresql_setup.sql @@ -1,5 +1,6 @@ -- Create extensions and types. create extension hstore; +create extension ltree; create domain uint64 as numeric(20,0); -- Create users for different types of connections and authentication. From 0819a17da8863cd8bd6819dfed246d6101f67086 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 27 Jan 2024 08:42:48 -0600 Subject: [PATCH 08/38] Remove openssl from TLS test setup TLS setup and tests were rather finicky. It seems that openssl 3 encrypts certificates differently than older openssl and it does it in a way Go and/or pgx ssl handling code can't handle. It appears that this related to the use of a deprecated client certificate encryption system. This caused CI to be stuck on Ubuntu 20.04 and recently caused the contributing guide to fail to work on MacOS. Remove openssl from the test setup and replace it with a Go program that generates the certificates. --- .github/workflows/ci.yml | 5 +- CONTRIBUTING.md | 18 +--- ci/setup_test.bash | 15 +-- testsetup/ca.cnf | 6 -- testsetup/generate_certs.go | 187 ++++++++++++++++++++++++++++++++++++ testsetup/localhost.cnf | 13 --- testsetup/pgx_sslcert.cnf | 9 -- 7 files changed, 192 insertions(+), 61 deletions(-) delete mode 100644 testsetup/ca.cnf create mode 100644 testsetup/generate_certs.go delete mode 100644 testsetup/localhost.cnf delete mode 100644 testsetup/pgx_sslcert.cnf diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c1bf91c3..f45c3045 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,10 +9,7 @@ on: jobs: test: name: Test - # Note: The TLS tests are rather finicky. It seems that openssl 3 encrypts certificates differently than older - # openssl and it does it in a way Go and/or pgx ssl handling code can't handle. So stick with Ubuntu 20.04 until - # that is figured out. - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 strategy: matrix: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3eb0da5b..6ed3205c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -79,20 +79,11 @@ echo "listen_addresses = '127.0.0.1'" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql echo "port = $PGPORT" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf cat testsetup/postgresql_ssl.conf >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf cp testsetup/pg_hba.conf .testdb/$POSTGRESQL_DATA_DIR/pg_hba.conf -cp testsetup/ca.cnf .testdb -cp testsetup/localhost.cnf .testdb -cp testsetup/pgx_sslcert.cnf .testdb cd .testdb -# Generate a CA public / private key pair. -openssl genrsa -out ca.key 4096 -openssl req -x509 -config ca.cnf -new -nodes -key ca.key -sha256 -days 365 -subj '/O=pgx-test-root' -out ca.pem - -# Generate the certificate for localhost (the server). -openssl genrsa -out localhost.key 2048 -openssl req -new -config localhost.cnf -key localhost.key -out localhost.csr -openssl x509 -req -in localhost.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out localhost.crt -days 364 -sha256 -extfile localhost.cnf -extensions v3_req +# Generate CA, server, and encrypted client certificates. +go run ../testsetup/generate_certs.go # Copy certificates to server directory and set permissions. cp ca.pem $POSTGRESQL_DATA_DIR/root.crt @@ -100,11 +91,6 @@ cp localhost.key $POSTGRESQL_DATA_DIR/server.key chmod 600 $POSTGRESQL_DATA_DIR/server.key cp localhost.crt $POSTGRESQL_DATA_DIR/server.crt -# Generate the certificate for client authentication. -openssl genrsa -des3 -out pgx_sslcert.key -passout pass:certpw 2048 -openssl req -new -config pgx_sslcert.cnf -key pgx_sslcert.key -passin pass:certpw -out pgx_sslcert.csr -openssl x509 -req -in pgx_sslcert.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out pgx_sslcert.crt -days 363 -sha256 -extfile pgx_sslcert.cnf -extensions v3_req - cd .. ``` diff --git a/ci/setup_test.bash b/ci/setup_test.bash index f96d2b72..66ba07d4 100755 --- a/ci/setup_test.bash +++ b/ci/setup_test.bash @@ -16,14 +16,8 @@ then cd testsetup - # Generate a CA public / private key pair. - openssl genrsa -out ca.key 4096 - openssl req -x509 -config ca.cnf -new -nodes -key ca.key -sha256 -days 365 -subj '/O=pgx-test-root' -out ca.pem - - # Generate the certificate for localhost (the server). - openssl genrsa -out localhost.key 2048 - openssl req -new -config localhost.cnf -key localhost.key -out localhost.csr - openssl x509 -req -in localhost.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out localhost.crt -days 364 -sha256 -extfile localhost.cnf -extensions v3_req + # Generate CA, server, and encrypted client certificates. + go run generate_certs.go # Copy certificates to server directory and set permissions. sudo cp ca.pem /var/lib/postgresql/$PGVERSION/main/root.crt @@ -34,11 +28,6 @@ then sudo cp localhost.crt /var/lib/postgresql/$PGVERSION/main/server.crt sudo chown postgres:postgres /var/lib/postgresql/$PGVERSION/main/server.crt - # Generate the certificate for client authentication. - openssl genrsa -des3 -out pgx_sslcert.key -passout pass:certpw 2048 - openssl req -new -config pgx_sslcert.cnf -key pgx_sslcert.key -passin pass:certpw -out pgx_sslcert.csr - openssl x509 -req -in pgx_sslcert.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out pgx_sslcert.crt -days 363 -sha256 -extfile pgx_sslcert.cnf -extensions v3_req - cp ca.pem /tmp cp pgx_sslcert.key /tmp cp pgx_sslcert.crt /tmp diff --git a/testsetup/ca.cnf b/testsetup/ca.cnf deleted file mode 100644 index bd018037..00000000 --- a/testsetup/ca.cnf +++ /dev/null @@ -1,6 +0,0 @@ -[ req ] -distinguished_name = dn -[ dn ] -commonName = ca -[ ext ] -basicConstraints =CA:TRUE,pathlen:0 diff --git a/testsetup/generate_certs.go b/testsetup/generate_certs.go new file mode 100644 index 00000000..945c6c5e --- /dev/null +++ b/testsetup/generate_certs.go @@ -0,0 +1,187 @@ +// Generates a CA, server certificate, and encrypted client certificate for testing pgx. + +package main + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "os" + "time" +) + +func main() { + // Create the CA + ca := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "pgx-root-ca", + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(20, 0, 0), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + caKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + panic(err) + } + + caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caKey.PublicKey, caKey) + if err != nil { + panic(err) + } + + err = writePrivateKey("ca.key", caKey) + if err != nil { + panic(err) + } + + err = writeCertificate("ca.pem", caBytes) + if err != nil { + panic(err) + } + + // Create a server certificate signed by the CA for localhost. + serverCert := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{ + CommonName: "localhost", + }, + DNSNames: []string{"localhost"}, + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback}, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(20, 0, 0), + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + + serverCertPrivKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic(err) + } + + serverBytes, err := x509.CreateCertificate(rand.Reader, serverCert, ca, &serverCertPrivKey.PublicKey, caKey) + if err != nil { + panic(err) + } + + err = writePrivateKey("localhost.key", serverCertPrivKey) + if err != nil { + panic(err) + } + + err = writeCertificate("localhost.crt", serverBytes) + if err != nil { + panic(err) + } + + // Create a client certificate signed by the CA and encrypted. + clientCert := &x509.Certificate{ + SerialNumber: big.NewInt(3), + Subject: pkix.Name{ + CommonName: "pgx_sslcert", + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(20, 0, 0), + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + + clientCertPrivKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic(err) + } + + clientBytes, err := x509.CreateCertificate(rand.Reader, clientCert, ca, &clientCertPrivKey.PublicKey, caKey) + if err != nil { + panic(err) + } + + writeEncryptedPrivateKey("pgx_sslcert.key", clientCertPrivKey, "certpw") + if err != nil { + panic(err) + } + + writeCertificate("pgx_sslcert.crt", clientBytes) + if err != nil { + panic(err) + } +} + +func writePrivateKey(path string, privateKey *rsa.PrivateKey) error { + file, err := os.Create(path) + if err != nil { + return fmt.Errorf("writePrivateKey: %w", err) + } + + err = pem.Encode(file, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + if err != nil { + return fmt.Errorf("writePrivateKey: %w", err) + } + + err = file.Close() + if err != nil { + return fmt.Errorf("writePrivateKey: %w", err) + } + + return nil +} + +func writeEncryptedPrivateKey(path string, privateKey *rsa.PrivateKey, password string) error { + file, err := os.Create(path) + if err != nil { + return fmt.Errorf("writeEncryptedPrivateKey: %w", err) + } + + block, err := x509.EncryptPEMBlock(rand.Reader, "CERTIFICATE", x509.MarshalPKCS1PrivateKey(privateKey), []byte(password), x509.PEMCipher3DES) + if err != nil { + return fmt.Errorf("writeEncryptedPrivateKey: %w", err) + } + + err = pem.Encode(file, block) + if err != nil { + return fmt.Errorf("writeEncryptedPrivateKey: %w", err) + } + + err = file.Close() + if err != nil { + return fmt.Errorf("writeEncryptedPrivateKey: %w", err) + } + + return nil + +} + +func writeCertificate(path string, certBytes []byte) error { + file, err := os.Create(path) + if err != nil { + return fmt.Errorf("writeCertificate: %w", err) + } + + err = pem.Encode(file, &pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + }) + if err != nil { + return fmt.Errorf("writeCertificate: %w", err) + } + + err = file.Close() + if err != nil { + return fmt.Errorf("writeCertificate: %w", err) + } + + return nil +} diff --git a/testsetup/localhost.cnf b/testsetup/localhost.cnf deleted file mode 100644 index 14dcd57f..00000000 --- a/testsetup/localhost.cnf +++ /dev/null @@ -1,13 +0,0 @@ -[ req ] -default_bits = 2048 -distinguished_name = dn -req_extensions = v3_req -prompt = no -[ dn ] -commonName = localhost -[ v3_req ] -subjectAltName = @alt_names -keyUsage = digitalSignature -extendedKeyUsage = serverAuth -[alt_names] -DNS.1 = localhost diff --git a/testsetup/pgx_sslcert.cnf b/testsetup/pgx_sslcert.cnf deleted file mode 100644 index 2d5d0ff7..00000000 --- a/testsetup/pgx_sslcert.cnf +++ /dev/null @@ -1,9 +0,0 @@ -[ req ] -default_bits = 2048 -distinguished_name = dn -req_extensions = v3_req -prompt = no -[ dn ] -commonName = pgx_sslcert -[ v3_req ] -keyUsage = digitalSignature From 7b5fcac46526c55c6c3ed32812a0002104bae2c3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 27 Jan 2024 18:55:59 -0600 Subject: [PATCH 09/38] Add timetz and []timetz OID constants https://github.com/jackc/pgx/issues/1883 --- pgtype/pgtype.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 4c2532d2..08833f87 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -81,6 +81,8 @@ const ( IntervalOID = 1186 IntervalArrayOID = 1187 NumericArrayOID = 1231 + TimetzOID = 1266 + TimetzArrayOID = 1270 BitOID = 1560 BitArrayOID = 1561 VarbitOID = 1562 From 34da2fed9570ec3ff22dfa181323ef9cf702b2ca Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Feb 2024 09:49:56 -0600 Subject: [PATCH 10/38] Improve CopyFrom auto-conversion of text-ish values CopyFrom requires that all values are encoded in the binary format. It already tried to parse strings to values that can then be encoded into the binary format. But it didn't handle types that can be encoded as text and then parsed and converted to binary. It now does. --- copy_from_test.go | 34 ++++++++++++++++++++++++++++++++++ values.go | 6 +++++- 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/copy_from_test.go b/copy_from_test.go index 9da23c04..423337e4 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -803,6 +803,40 @@ func TestConnCopyFromAutomaticStringConversion(t *testing.T) { ensureConnValid(t, conn) } +// https://github.com/jackc/pgx/discussions/1891 +func TestConnCopyFromAutomaticStringConversionArray(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a numeric[] + )`) + + inputRows := [][]interface{}{ + {[]string{"42"}}, + {[]string{"7"}}, + {[]string{"8", "9"}}, + {[][]string{{"10", "11"}, {"12", "13"}}}, + } + + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows)) + require.NoError(t, err) + require.EqualValues(t, len(inputRows), copyCount) + + // Test reads as int64 and flattened array for simplicity. + rows, _ := conn.Query(ctx, "select * from foo") + nums, err := pgx.CollectRows(rows, pgx.RowTo[[]int64]) + require.NoError(t, err) + require.Equal(t, [][]int64{{42}, {7}, {8, 9}, {10, 11, 12, 13}}, nums) + + ensureConnValid(t, conn) +} + func TestCopyFromFunc(t *testing.T) { t.Parallel() diff --git a/values.go b/values.go index 19c642fa..cab717d0 100644 --- a/values.go +++ b/values.go @@ -55,7 +55,11 @@ func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, er func tryScanStringCopyValueThenEncode(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) { s, ok := arg.(string) if !ok { - return nil, errors.New("not a string") + textBuf, err := m.Encode(oid, TextFormatCode, arg, nil) + if err != nil { + return nil, errors.New("not a string and cannot be encoded as text") + } + s = string(textBuf) } var v any From fd4411453fbd592586601b71f05eb06ea1c74906 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Feb 2024 10:29:10 -0600 Subject: [PATCH 11/38] Improve Conn.LoadType documentation --- conn.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/conn.go b/conn.go index deb0f48c..96ed452d 100644 --- a/conn.go +++ b/conn.go @@ -1203,7 +1203,15 @@ func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) { return sanitize.SanitizeSQL(sql, valueArgs...) } -// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration. +// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration. typeName must be +// the name of a type where the underlying type(s) is already understood by pgx. It is for derived types. In particular, +// typeName must be one of the following: +// - An array type name of a type that is already registered. e.g. "_foo" when "foo" is registered. +// - A composite type name where all field types are already registered. +// - A domain type name where the base type is already registered. +// - An enum type name. +// - A range type name where the element type is already registered. +// - A multirange type name where the element type is already registered. func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, error) { var oid uint32 From 832b4f97718c2d9d2eb16bbd2fef1d05ede7aab5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Feb 2024 12:25:57 -0600 Subject: [PATCH 12/38] Fix: prepared statement already exists When a conn is going to execute a query, the first thing it does is to deallocate any invalidated prepared statements from the statement cache. However, the statements were removed from the cache regardless of whether the deallocation succeeded. This would cause subsequent calls of the same SQL to fail with "prepared statement already exists" error. This problem is easy to trigger by running a query with a context that is already canceled. This commit changes the deallocate invalidated cached statements logic so that the statements are only removed from the cache if the deallocation was successful on the server. https://github.com/jackc/pgx/issues/1847 --- conn.go | 10 ++++++--- conn_test.go | 29 +++++++++++++++++++++++++++ internal/stmtcache/lru_cache.go | 14 ++++++++----- internal/stmtcache/stmtcache.go | 9 +++++++-- internal/stmtcache/unlimited_cache.go | 12 ++++++++--- 5 files changed, 61 insertions(+), 13 deletions(-) diff --git a/conn.go b/conn.go index 96ed452d..a7a5ef73 100644 --- a/conn.go +++ b/conn.go @@ -1359,12 +1359,12 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error } if c.descriptionCache != nil { - c.descriptionCache.HandleInvalidated() + c.descriptionCache.RemoveInvalidated() } var invalidatedStatements []*pgconn.StatementDescription if c.statementCache != nil { - invalidatedStatements = c.statementCache.HandleInvalidated() + invalidatedStatements = c.statementCache.GetInvalidated() } if len(invalidatedStatements) == 0 { @@ -1376,7 +1376,6 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error for _, sd := range invalidatedStatements { pipeline.SendDeallocate(sd.Name) - delete(c.preparedStatements, sd.Name) } err := pipeline.Sync() @@ -1389,5 +1388,10 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error return fmt.Errorf("failed to deallocate cached statement(s): %w", err) } + c.statementCache.RemoveInvalidated() + for _, sd := range invalidatedStatements { + delete(c.preparedStatements, sd.Name) + } + return nil } diff --git a/conn_test.go b/conn_test.go index a7f7f2f8..e9415b22 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1338,3 +1338,32 @@ func TestRawValuesUnderlyingMemoryReused(t *testing.T) { t.Fatal("expected buffer from RawValues to be overwritten by subsequent queries but it was not") }) } + +// https://github.com/jackc/pgx/issues/1847 +func TestConnDeallocateInvalidatedCachedStatementsWhenCanceled(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var n int32 + err := conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + + // Divide by zero causes an error. baseRows.Close() calls Invalidate on the statement cache whenever an error was + // encountered by the query. Use this to purposely invalidate the query. If we had access to private fields of conn + // we could call conn.statementCache.InvalidateAll() instead. + err = conn.QueryRow(ctx, "select 1 / $1::int", 0).Scan(&n) + require.Error(t, err) + + ctx2, cancel2 := context.WithCancel(ctx) + cancel2() + err = conn.QueryRow(ctx2, "select 1 / $1::int", 1).Scan(&n) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + + err = conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + }) +} diff --git a/internal/stmtcache/lru_cache.go b/internal/stmtcache/lru_cache.go index 859345fc..dec83f47 100644 --- a/internal/stmtcache/lru_cache.go +++ b/internal/stmtcache/lru_cache.go @@ -81,12 +81,16 @@ func (c *LRUCache) InvalidateAll() { c.l = list.New() } -// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated. -// Typically, the caller will then deallocate them. -func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription { - invalidStmts := c.invalidStmts +// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. +func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription { + return c.invalidStmts +} + +// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a +// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were +// never seen by the call to GetInvalidated. +func (c *LRUCache) RemoveInvalidated() { c.invalidStmts = nil - return invalidStmts } // Len returns the number of cached prepared statement descriptions. diff --git a/internal/stmtcache/stmtcache.go b/internal/stmtcache/stmtcache.go index b2940e23..d57bdd29 100644 --- a/internal/stmtcache/stmtcache.go +++ b/internal/stmtcache/stmtcache.go @@ -29,8 +29,13 @@ type Cache interface { // InvalidateAll invalidates all statement descriptions. InvalidateAll() - // HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated. - HandleInvalidated() []*pgconn.StatementDescription + // GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. + GetInvalidated() []*pgconn.StatementDescription + + // RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a + // call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were + // never seen by the call to GetInvalidated. + RemoveInvalidated() // Len returns the number of cached prepared statement descriptions. Len() int diff --git a/internal/stmtcache/unlimited_cache.go b/internal/stmtcache/unlimited_cache.go index f5f59396..69641329 100644 --- a/internal/stmtcache/unlimited_cache.go +++ b/internal/stmtcache/unlimited_cache.go @@ -54,10 +54,16 @@ func (c *UnlimitedCache) InvalidateAll() { c.m = make(map[string]*pgconn.StatementDescription) } -func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription { - invalidStmts := c.invalidStmts +// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. +func (c *UnlimitedCache) GetInvalidated() []*pgconn.StatementDescription { + return c.invalidStmts +} + +// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a +// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were +// never seen by the call to GetInvalidated. +func (c *UnlimitedCache) RemoveInvalidated() { c.invalidStmts = nil - return invalidStmts } // Len returns the number of cached prepared statement descriptions. From 7caa448ac8858ee0c3fdfc15d11c32cf3bba34f3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Feb 2024 12:41:59 -0600 Subject: [PATCH 13/38] Skip test on CockroachDB --- conn_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/conn_test.go b/conn_test.go index e9415b22..75861053 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1345,6 +1345,8 @@ func TestConnDeallocateInvalidatedCachedStatementsWhenCanceled(t *testing.T) { defer cancel() pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "CockroachDB returns decimal instead of integer for integer division") + var n int32 err := conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n) require.NoError(t, err) From 576b6c88f631354d8deaddf9fd30cdb0632adb1a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Feb 2024 12:50:18 -0600 Subject: [PATCH 14/38] Bump actions/setup-go version This gets rid of some deprecation warnings on Github Actions. --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f45c3045..1e494fdd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -74,7 +74,7 @@ jobs: uses: actions/checkout@v4 - name: Set up Go ${{ matrix.go-version }} - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} @@ -137,7 +137,7 @@ jobs: uses: actions/checkout@v4 - name: Set up Go ${{ matrix.go-version }} - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} From 6f8f6ede6c67ce6b755feafe5c44f93577545147 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Feb 2024 12:52:29 -0600 Subject: [PATCH 15/38] Update changelog for v5.5.3 --- CHANGELOG.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6469c183..4fcbc247 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,13 @@ +# 5.5.3 (February 3, 2024) + +* Fix: prepared statement already exists +* Improve CopyFrom auto-conversion of text-ish values +* Add ltree type support (Florent Viel) +* Make some properties of Batch and QueuedQuery public (Pavlo Golub) +* Add AppendRows function (Edoardo Spadolini) +* Optimize convert UUID [16]byte to string (Kirill Malikov) +* Fix: LargeObject Read and Write of more than ~1GB at a time (Mitar) + # 5.5.2 (January 13, 2024) * Allow NamedArgs to start with underscore From 5c63f646f820ca9696fc3515c1caf2a557d562e5 Mon Sep 17 00:00:00 2001 From: Tom Payne Date: Mon, 5 Feb 2024 04:18:15 +0100 Subject: [PATCH 16/38] Add link to github.com/twpayne/pgx-geos --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 8b890836..49f2c3d7 100644 --- a/README.md +++ b/README.md @@ -120,6 +120,7 @@ pgerrcode contains constants for the PostgreSQL error codes. * [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid) * [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal) +* [github.com/twpayne/pgx-geos](https://github.com/twpayne/pgx-geos) ([PostGIS](https://postgis.net/) and [GEOS](https://libgeos.org/) via [go-geos](https://github.com/twpayne/go-geos)) * [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid) From 654dcab93eedc3cbd6ebcf6f8d1d950f314d18d4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 23 Feb 2024 17:40:11 -0600 Subject: [PATCH 17/38] Fix: pgtype.Bits makes copy of data from read buffer It was taking a reference. This would cause the data to be corrupted by future reads. fixes #1909 --- pgtype/bits.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pgtype/bits.go b/pgtype/bits.go index 30558118..e7a1d016 100644 --- a/pgtype/bits.go +++ b/pgtype/bits.go @@ -176,8 +176,10 @@ func (scanPlanBinaryBitsToBitsScanner) Scan(src []byte, dst any) error { bitLen := int32(binary.BigEndian.Uint32(src)) rp := 4 + buf := make([]byte, len(src[rp:])) + copy(buf, src[rp:]) - return scanner.ScanBits(Bits{Bytes: src[rp:], Len: bitLen, Valid: true}) + return scanner.ScanBits(Bits{Bytes: buf, Len: bitLen, Valid: true}) } type scanPlanTextAnyToBitsScanner struct{} From 85f15c4b3c76d3dfe139d95f983937203b49d933 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 23 Feb 2024 18:18:03 -0600 Subject: [PATCH 18/38] Fix scan float4 into sql.Scanner https://github.com/jackc/pgx/issues/1911 --- pgtype/float4.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pgtype/float4.go b/pgtype/float4.go index 91ca0147..8646d9d2 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -297,12 +297,12 @@ func (c Float4Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, sr return nil, nil } - var n float64 + var n float32 err := codecScan(c, m, oid, format, src, &n) if err != nil { return nil, err } - return n, nil + return float64(n), nil } func (c Float4Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { From 8896bd697781ed4aee392daa90b90cde142319fe Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Feb 2024 09:24:26 -0600 Subject: [PATCH 19/38] Handle invalid sslkey file https://github.com/jackc/pgx/issues/1915 --- pgconn/config.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pgconn/config.go b/pgconn/config.go index ddde89bd..33a72257 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -721,6 +721,9 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P return nil, fmt.Errorf("unable to read sslkey: %w", err) } block, _ := pem.Decode(buf) + if block == nil { + return nil, errors.New("failed to decode sslkey") + } var pemKey []byte var decryptedKey []byte var decryptedError error From 046f497efb4e92caa9575a0e9c351e4906af14c6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Feb 2024 10:16:18 -0600 Subject: [PATCH 20/38] deallocateInvalidatedCachedStatements now runs in transactions https://github.com/jackc/pgx/issues/1847 --- conn.go | 2 +- conn_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/conn.go b/conn.go index a7a5ef73..fc72c732 100644 --- a/conn.go +++ b/conn.go @@ -1354,7 +1354,7 @@ order by attnum`, } func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error { - if c.pgConn.TxStatus() != 'I' { + if txStatus := c.pgConn.TxStatus(); txStatus != 'I' && txStatus != 'T' { return nil } diff --git a/conn_test.go b/conn_test.go index 75861053..df8c9186 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1369,3 +1369,42 @@ func TestConnDeallocateInvalidatedCachedStatementsWhenCanceled(t *testing.T) { require.EqualValues(t, 1, n) }) } + +// https://github.com/jackc/pgx/issues/1847 +func TestConnDeallocateInvalidatedCachedStatementsInTransactionWithBatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + connString := os.Getenv("PGX_TEST_DATABASE") + config := mustParseConfig(t, connString) + config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement + config.StatementCacheCapacity = 2 + + conn, err := pgx.ConnectConfig(ctx, config) + require.NoError(t, err) + + tx, err := conn.Begin(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + + _, err = tx.Exec(ctx, "select $1::int + 1", 1) + require.NoError(t, err) + + _, err = tx.Exec(ctx, "select $1::int + 2", 1) + require.NoError(t, err) + + // This should invalidate the first cached statement. + _, err = tx.Exec(ctx, "select $1::int + 3", 1) + require.NoError(t, err) + + batch := &pgx.Batch{} + batch.Queue("select $1::int + 1", 1) + err = tx.SendBatch(ctx, batch).Close() + require.NoError(t, err) + + err = tx.Rollback(ctx) + require.NoError(t, err) + + ensureConnValid(t, conn) +} From d149d3fe5c50d1d98bd6265d3c928519ba4b3f4b Mon Sep 17 00:00:00 2001 From: David Kurman Date: Sun, 25 Feb 2024 17:56:47 +0200 Subject: [PATCH 21/38] Fix panic in TryFindUnderlyingTypeScanPlan Check if CanConvert before calling reflect.Value.Convert --- pgtype/pgtype.go | 2 +- pgtype/pgtype_test.go | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 08833f87..534ef6d1 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -561,7 +561,7 @@ func TryFindUnderlyingTypeScanPlan(dst any) (plan WrappedScanPlanNextSetter, nex } } - if nextDstType != nil && dstValue.Type() != nextDstType { + if nextDstType != nil && dstValue.Type() != nextDstType && dstValue.CanConvert(nextDstType) { return &underlyingTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true } diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index b6e3371f..c397069b 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -35,6 +35,7 @@ func init() { // Test for renamed types type _string string type _bool bool +type _uint8 uint8 type _int8 int8 type _int16 int16 type _int16Slice []int16 @@ -453,6 +454,14 @@ func TestMapScanNullToWrongType(t *testing.T) { assert.False(t, pn.Valid) } +func TestScanToSliceOfRenamedUint8(t *testing.T) { + m := pgtype.NewMap() + var ruint8 []_uint8 + err := m.Scan(pgtype.Int2ArrayOID, pgx.TextFormatCode, []byte("{2,4}"), &ruint8) + assert.NoError(t, err) + assert.Equal(t, []_uint8{2, 4}, ruint8) +} + func TestMapScanTextToBool(t *testing.T) { tests := []struct { name string From 2e84dccaf57b4fa803149884f2149dfa9e923593 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 29 Feb 2024 18:44:01 -0600 Subject: [PATCH 22/38] *Pipeline.getResults should close pipeline on error Otherwise, it might be possible to panic when closing the pipeline if it tries to read a connection that should be closed but still has a fatal error on the wire. https://github.com/jackc/pgx/issues/1920 --- pgconn/pgconn.go | 2 ++ pgconn/pgconn_test.go | 83 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index b287e020..ad81ec60 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -2094,6 +2094,8 @@ func (p *Pipeline) getResults() (results any, err error) { for { msg, err := p.conn.receiveMessage() if err != nil { + p.closed = true + p.err = err p.conn.asyncClose() return nil, normalizeTimeoutError(p.ctx, err) } diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 7ca3992e..f04fa79a 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -3389,3 +3389,86 @@ func TestSNISupport(t *testing.T) { }) } } + +// https://github.com/jackc/pgx/issues/1920 +func TestFatalErrorReceivedInPipelineMode(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + steps := pgmock.AcceptUnauthenticatedConnRequestSteps() + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{})) + steps = append(steps, pgmock.SendMessage(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{ + {Name: []byte("mock")}, + }})) + steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"})) + // We shouldn't get anything after the first fatal error. But the reported issue was with PgBouncer so maybe that + // causes the issue. Anyway, a FATAL error after the connection had already been killed could cause a panic. + steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"})) + + script := &pgmock.Script{Steps: steps} + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverKeepAlive := make(chan struct{}) + defer close(serverKeepAlive) + + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(59 * time.Second)) + if err != nil { + serverErrChan <- err + return + } + + err = script.Run(pgproto3.NewBackend(conn, conn)) + if err != nil { + serverErrChan <- err + return + } + + <-serverKeepAlive + }() + + parts := strings.Split(ln.Addr().String(), ":") + host := parts[0] + port := parts[1] + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) + + ctx, cancel = context.WithTimeout(ctx, 59*time.Second) + defer cancel() + conn, err := pgconn.Connect(ctx, connStr) + require.NoError(t, err) + + pipeline := conn.StartPipeline(ctx) + pipeline.SendPrepare("s1", "select 1", nil) + pipeline.SendPrepare("s2", "select 2", nil) + pipeline.SendPrepare("s3", "select 3", nil) + err = pipeline.Sync() + require.NoError(t, err) + + _, err = pipeline.GetResults() + require.NoError(t, err) + _, err = pipeline.GetResults() + require.Error(t, err) + + err = pipeline.Close() + require.Error(t, err) +} From 88dfc22ae4aa031783cda90841d5358edd85ff2c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 2 Mar 2024 15:12:20 -0600 Subject: [PATCH 23/38] Fix simple protocol encoding of json.RawMessage The underlying type of json.RawMessage is a []byte so to avoid it being considered binary data we need to handle it specifically. This is done by registerDefaultPgTypeVariants. In addition, handle json.RawMessage in the JSONCodec PlanEncode to avoid it being mutated by json.Marshal. https://github.com/jackc/pgx/issues/1763 --- pgtype/json.go | 17 +++++++++++++++++ pgtype/pgtype_default.go | 2 ++ pgtype/pgtype_test.go | 8 ++++++++ 3 files changed, 27 insertions(+) diff --git a/pgtype/json.go b/pgtype/json.go index 3f1a750f..99628092 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -25,6 +25,11 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod case []byte: return encodePlanJSONCodecEitherFormatByteSlice{} + // Handle json.RawMessage specifically because if it is run through json.Marshal it may be mutated. + // e.g. `{"foo": "bar"}` -> `{"foo":"bar"}`. + case json.RawMessage: + return encodePlanJSONCodecEitherFormatJSONRawMessage{} + // Cannot rely on driver.Valuer being handled later because anything can be marshalled. // // https://github.com/jackc/pgx/issues/1430 @@ -79,6 +84,18 @@ func (encodePlanJSONCodecEitherFormatByteSlice) Encode(value any, buf []byte) (n return buf, nil } +type encodePlanJSONCodecEitherFormatJSONRawMessage struct{} + +func (encodePlanJSONCodecEitherFormatJSONRawMessage) Encode(value any, buf []byte) (newBuf []byte, err error) { + jsonBytes := value.(json.RawMessage) + if jsonBytes == nil { + return nil, nil + } + + buf = append(buf, jsonBytes...) + return buf, nil +} + type encodePlanJSONCodecEitherFormatMarshal struct{} func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) { diff --git a/pgtype/pgtype_default.go b/pgtype/pgtype_default.go index 58f4b92c..c21ac081 100644 --- a/pgtype/pgtype_default.go +++ b/pgtype/pgtype_default.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/json" "net" "net/netip" "reflect" @@ -173,6 +174,7 @@ func initDefaultMap() { registerDefaultPgTypeVariants[time.Time](defaultMap, "timestamptz") registerDefaultPgTypeVariants[time.Duration](defaultMap, "interval") registerDefaultPgTypeVariants[string](defaultMap, "text") + registerDefaultPgTypeVariants[json.RawMessage](defaultMap, "json") registerDefaultPgTypeVariants[[]byte](defaultMap, "bytea") registerDefaultPgTypeVariants[net.IP](defaultMap, "inet") diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index c397069b..b670e92b 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -546,6 +546,14 @@ func TestMapEncodePlanCacheUUIDTypeConfusion(t *testing.T) { require.Error(t, err) } +// https://github.com/jackc/pgx/issues/1763 +func TestMapEncodeRawJSONIntoUnknownOID(t *testing.T) { + m := pgtype.NewMap() + buf, err := m.Encode(0, pgtype.TextFormatCode, json.RawMessage(`{"foo": "bar"}`), nil) + require.NoError(t, err) + require.Equal(t, []byte(`{"foo": "bar"}`), buf) +} + func BenchmarkMapScanInt4IntoBinaryDecoder(b *testing.B) { m := pgtype.NewMap() src := []byte{0, 0, 0, 42} From c1b0a01ca75ac9eb3a7dbc1396f583ab5dbf9557 Mon Sep 17 00:00:00 2001 From: Felix <23635466+its-felix@users.noreply.github.com> Date: Sun, 3 Mar 2024 07:30:22 +0100 Subject: [PATCH 24/38] Fix behavior of CollectRows to return empty slice if Rows are empty https://github.com/jackc/pgx/issues/1924 --- rows.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rows.go b/rows.go index 17e36cba..78ef5326 100644 --- a/rows.go +++ b/rows.go @@ -438,7 +438,7 @@ func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) { // CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T. func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) { - return AppendRows([]T(nil), rows, fn) + return AppendRows([]T{}, rows, fn) } // CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true. From adbb38f298c76e283ffc7c7a3f571036fea47fd4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 2 Mar 2024 11:24:16 -0600 Subject: [PATCH 25/38] Do not allow protocol messages larger than ~1GB The PostgreSQL server will reject messages greater than ~1 GB anyway. However, worse than that is that a message that is larger than 4 GB could wrap the 32-bit integer message size and be interpreted by the server as multiple messages. This could allow a malicious client to inject arbitrary protocol messages. https://github.com/jackc/pgx/security/advisories/GHSA-mrww-27vc-gghv --- pgconn/pgconn.go | 46 +++++- pgconn/pgconn_test.go | 13 +- pgproto3/authentication_cleartext_password.go | 7 +- pgproto3/authentication_gss.go | 7 +- pgproto3/authentication_gss_continue.go | 7 +- pgproto3/authentication_md5_password.go | 7 +- pgproto3/authentication_ok.go | 7 +- pgproto3/authentication_sasl.go | 10 +- pgproto3/authentication_sasl_continue.go | 12 +- pgproto3/authentication_sasl_final.go | 12 +- pgproto3/backend.go | 25 +++- pgproto3/backend_key_data.go | 7 +- pgproto3/backend_test.go | 4 +- pgproto3/bind.go | 10 +- pgproto3/bind_complete.go | 4 +- pgproto3/bind_test.go | 20 +++ pgproto3/cancel_request.go | 4 +- pgproto3/close.go | 14 +- pgproto3/close_complete.go | 4 +- pgproto3/command_complete.go | 14 +- pgproto3/copy_both_response.go | 10 +- pgproto3/copy_both_response_test.go | 4 +- pgproto3/copy_data.go | 9 +- pgproto3/copy_done.go | 4 +- pgproto3/copy_fail.go | 14 +- pgproto3/copy_in_response.go | 10 +- pgproto3/copy_out_response.go | 10 +- pgproto3/data_row.go | 10 +- pgproto3/describe.go | 14 +- pgproto3/empty_query_response.go | 4 +- pgproto3/error_response.go | 135 ++++++++--------- pgproto3/example/pgfortune/server.go | 21 ++- pgproto3/execute.go | 13 +- pgproto3/flush.go | 4 +- pgproto3/frontend.go | 137 ++++++++++++++---- pgproto3/function_call.go | 9 +- pgproto3/function_call_response.go | 10 +- pgproto3/function_call_test.go | 7 +- pgproto3/gss_enc_request.go | 4 +- pgproto3/gss_response.go | 9 +- pgproto3/no_data.go | 4 +- pgproto3/notice_response.go | 6 +- pgproto3/notification_response.go | 12 +- pgproto3/parameter_description.go | 10 +- pgproto3/parameter_status.go | 14 +- pgproto3/parse.go | 10 +- pgproto3/parse_complete.go | 4 +- pgproto3/password_message.go | 11 +- pgproto3/pgproto3.go | 28 +++- pgproto3/pgproto3_private_test.go | 3 + pgproto3/portal_suspended.go | 4 +- pgproto3/query.go | 11 +- pgproto3/query_test.go | 20 +++ pgproto3/ready_for_query.go | 4 +- pgproto3/row_description.go | 10 +- pgproto3/sasl_initial_response.go | 10 +- pgproto3/sasl_response.go | 11 +- pgproto3/ssl_request.go | 4 +- pgproto3/startup_message.go | 6 +- pgproto3/sync.go | 4 +- pgproto3/terminate.go | 4 +- 61 files changed, 472 insertions(+), 390 deletions(-) create mode 100644 pgproto3/bind_test.go create mode 100644 pgproto3/pgproto3_private_test.go create mode 100644 pgproto3/query_test.go diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index ad81ec60..0bf03f33 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -1674,25 +1674,55 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { // Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. type Batch struct { buf []byte + err error } // ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { - batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) + if batch.err != nil { + return + } batch.ExecPrepared("", paramValues, paramFormats, resultFormats) } // ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions. func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { - batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) - batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) - batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Execute{}).Encode(batch.buf) + if batch.err != nil { + return + } } // ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a // transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing // multiple queries in a single round trip than using pipeline mode. func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { + if batch.err != nil { + return &MultiResultReader{ + closed: true, + err: batch.err, + } + } + if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, @@ -1718,7 +1748,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR pgConn.contextWatcher.Watch(ctx) } - batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) + batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf) + if batch.err != nil { + multiResult.closed = true + multiResult.err = batch.err + pgConn.unlock() + return multiResult + } pgConn.enterPotentialWriteReadDeadlock() defer pgConn.exitPotentialWriteReadDeadlock() diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index f04fa79a..b77d21c1 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -3363,9 +3363,9 @@ func TestSNISupport(t *testing.T) { return } - srv.Write((&pgproto3.AuthenticationOk{}).Encode(nil)) - srv.Write((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil)) - srv.Write((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)) + srv.Write(mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil))) + srv.Write(mustEncode((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil))) + srv.Write(mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil))) serverSNINameChan <- sniHost }() @@ -3472,3 +3472,10 @@ func TestFatalErrorReceivedInPipelineMode(t *testing.T) { err = pipeline.Close() require.Error(t, err) } + +func mustEncode(buf []byte, err error) []byte { + if err != nil { + panic(err) + } + return buf +} diff --git a/pgproto3/authentication_cleartext_password.go b/pgproto3/authentication_cleartext_password.go index d8f98b9a..ac2962e9 100644 --- a/pgproto3/authentication_cleartext_password.go +++ b/pgproto3/authentication_cleartext_password.go @@ -35,11 +35,10 @@ func (dst *AuthenticationCleartextPassword) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, 8) +func (src *AuthenticationCleartextPassword) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/authentication_gss.go b/pgproto3/authentication_gss.go index 0d234222..178ef31d 100644 --- a/pgproto3/authentication_gss.go +++ b/pgproto3/authentication_gss.go @@ -27,11 +27,10 @@ func (a *AuthenticationGSS) Decode(src []byte) error { return nil } -func (a *AuthenticationGSS) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, 4) +func (a *AuthenticationGSS) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeGSS) - return dst + return finishMessage(dst, sp) } func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/authentication_gss_continue.go b/pgproto3/authentication_gss_continue.go index 63789dc1..2ba3f3b3 100644 --- a/pgproto3/authentication_gss_continue.go +++ b/pgproto3/authentication_gss_continue.go @@ -31,12 +31,11 @@ func (a *AuthenticationGSSContinue) Decode(src []byte) error { return nil } -func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, int32(len(a.Data))+8) +func (a *AuthenticationGSSContinue) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeGSSCont) dst = append(dst, a.Data...) - return dst + return finishMessage(dst, sp) } func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/authentication_md5_password.go b/pgproto3/authentication_md5_password.go index 5671c84c..854c6404 100644 --- a/pgproto3/authentication_md5_password.go +++ b/pgproto3/authentication_md5_password.go @@ -38,12 +38,11 @@ func (dst *AuthenticationMD5Password) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationMD5Password) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, 12) +func (src *AuthenticationMD5Password) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeMD5Password) dst = append(dst, src.Salt[:]...) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/authentication_ok.go b/pgproto3/authentication_ok.go index 88d648ae..ec11d39f 100644 --- a/pgproto3/authentication_ok.go +++ b/pgproto3/authentication_ok.go @@ -35,11 +35,10 @@ func (dst *AuthenticationOk) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationOk) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, 8) +func (src *AuthenticationOk) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeOk) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/authentication_sasl.go b/pgproto3/authentication_sasl.go index 59650d4c..e66580f4 100644 --- a/pgproto3/authentication_sasl.go +++ b/pgproto3/authentication_sasl.go @@ -47,10 +47,8 @@ func (dst *AuthenticationSASL) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationSASL) Encode(dst []byte) []byte { - dst = append(dst, 'R') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeSASL) for _, s := range src.AuthMechanisms { @@ -59,9 +57,7 @@ func (src *AuthenticationSASL) Encode(dst []byte) []byte { } dst = append(dst, 0) - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/authentication_sasl_continue.go b/pgproto3/authentication_sasl_continue.go index 2ce70a47..70fba4a6 100644 --- a/pgproto3/authentication_sasl_continue.go +++ b/pgproto3/authentication_sasl_continue.go @@ -38,17 +38,11 @@ func (dst *AuthenticationSASLContinue) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte { - dst = append(dst, 'R') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *AuthenticationSASLContinue) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeSASLContinue) - dst = append(dst, src.Data...) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/authentication_sasl_final.go b/pgproto3/authentication_sasl_final.go index a38a8b91..84976c2a 100644 --- a/pgproto3/authentication_sasl_final.go +++ b/pgproto3/authentication_sasl_final.go @@ -38,17 +38,11 @@ func (dst *AuthenticationSASLFinal) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte { - dst = append(dst, 'R') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *AuthenticationSASLFinal) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeSASLFinal) - dst = append(dst, src.Data...) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Unmarshaler. diff --git a/pgproto3/backend.go b/pgproto3/backend.go index efa909c3..d146c338 100644 --- a/pgproto3/backend.go +++ b/pgproto3/backend.go @@ -16,7 +16,8 @@ type Backend struct { // before it is actually transmitted (i.e. before Flush). tracer *tracer - wbuf []byte + wbuf []byte + encodeError error // Frontend message flyweights bind Bind @@ -55,11 +56,21 @@ func NewBackend(r io.Reader, w io.Writer) *Backend { return &Backend{cr: cr, w: w} } -// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is -// called. +// Send sends a message to the frontend (i.e. the client). The message is buffered until Flush is called. Any error +// encountered will be returned from Flush. func (b *Backend) Send(msg BackendMessage) { + if b.encodeError != nil { + return + } + prevLen := len(b.wbuf) - b.wbuf = msg.Encode(b.wbuf) + newBuf, err := msg.Encode(b.wbuf) + if err != nil { + b.encodeError = err + return + } + b.wbuf = newBuf + if b.tracer != nil { b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg) } @@ -67,6 +78,12 @@ func (b *Backend) Send(msg BackendMessage) { // Flush writes any pending messages to the frontend (i.e. the client). func (b *Backend) Flush() error { + if err := b.encodeError; err != nil { + b.encodeError = nil + b.wbuf = b.wbuf[:0] + return &writeError{err: err, safeToRetry: true} + } + n, err := b.w.Write(b.wbuf) const maxLen = 1024 diff --git a/pgproto3/backend_key_data.go b/pgproto3/backend_key_data.go index 12c60817..23f5da67 100644 --- a/pgproto3/backend_key_data.go +++ b/pgproto3/backend_key_data.go @@ -29,12 +29,11 @@ func (dst *BackendKeyData) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *BackendKeyData) Encode(dst []byte) []byte { - dst = append(dst, 'K') - dst = pgio.AppendUint32(dst, 12) +func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'K') dst = pgio.AppendUint32(dst, src.ProcessID) dst = pgio.AppendUint32(dst, src.SecretKey) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/backend_test.go b/pgproto3/backend_test.go index 5655122a..5107ef76 100644 --- a/pgproto3/backend_test.go +++ b/pgproto3/backend_test.go @@ -71,8 +71,8 @@ func TestStartupMessage(t *testing.T) { "username": "tester", }, } - dst := []byte{} - dst = want.Encode(dst) + dst, err := want.Encode([]byte{}) + require.NoError(t, err) server := &interruptReader{} server.push(dst) diff --git a/pgproto3/bind.go b/pgproto3/bind.go index fdd2d3b8..b32cd81c 100644 --- a/pgproto3/bind.go +++ b/pgproto3/bind.go @@ -108,10 +108,8 @@ func (dst *Bind) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Bind) Encode(dst []byte) []byte { - dst = append(dst, 'B') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *Bind) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'B') dst = append(dst, src.DestinationPortal...) dst = append(dst, 0) @@ -139,9 +137,7 @@ func (src *Bind) Encode(dst []byte) []byte { dst = pgio.AppendInt16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/bind_complete.go b/pgproto3/bind_complete.go index 3be256c8..bacf30d8 100644 --- a/pgproto3/bind_complete.go +++ b/pgproto3/bind_complete.go @@ -20,8 +20,8 @@ func (dst *BindComplete) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *BindComplete) Encode(dst []byte) []byte { - return append(dst, '2', 0, 0, 0, 4) +func (src *BindComplete) Encode(dst []byte) ([]byte, error) { + return append(dst, '2', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/bind_test.go b/pgproto3/bind_test.go new file mode 100644 index 00000000..6ec0e024 --- /dev/null +++ b/pgproto3/bind_test.go @@ -0,0 +1,20 @@ +package pgproto3_test + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/require" +) + +func TestBindBiggerThanMaxMessageBodyLen(t *testing.T) { + t.Parallel() + + // Maximum allowed size. + _, err := (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-16)}}).Encode(nil) + require.NoError(t, err) + + // 1 byte too big + _, err = (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-15)}}).Encode(nil) + require.Error(t, err) +} diff --git a/pgproto3/cancel_request.go b/pgproto3/cancel_request.go index 8fcf8217..6b52dd97 100644 --- a/pgproto3/cancel_request.go +++ b/pgproto3/cancel_request.go @@ -36,12 +36,12 @@ func (dst *CancelRequest) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 4 byte message length. -func (src *CancelRequest) Encode(dst []byte) []byte { +func (src *CancelRequest) Encode(dst []byte) ([]byte, error) { dst = pgio.AppendInt32(dst, 16) dst = pgio.AppendInt32(dst, cancelRequestCode) dst = pgio.AppendUint32(dst, src.ProcessID) dst = pgio.AppendUint32(dst, src.SecretKey) - return dst + return dst, nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/close.go b/pgproto3/close.go index f99b5943..0b50f27c 100644 --- a/pgproto3/close.go +++ b/pgproto3/close.go @@ -4,8 +4,6 @@ import ( "bytes" "encoding/json" "errors" - - "github.com/jackc/pgx/v5/internal/pgio" ) type Close struct { @@ -37,18 +35,12 @@ func (dst *Close) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Close) Encode(dst []byte) []byte { - dst = append(dst, 'C') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *Close) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'C') dst = append(dst, src.ObjectType) dst = append(dst, src.Name...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/close_complete.go b/pgproto3/close_complete.go index 1d7b8f08..833f7a12 100644 --- a/pgproto3/close_complete.go +++ b/pgproto3/close_complete.go @@ -20,8 +20,8 @@ func (dst *CloseComplete) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CloseComplete) Encode(dst []byte) []byte { - return append(dst, '3', 0, 0, 0, 4) +func (src *CloseComplete) Encode(dst []byte) ([]byte, error) { + return append(dst, '3', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/command_complete.go b/pgproto3/command_complete.go index 814027ca..eba70947 100644 --- a/pgproto3/command_complete.go +++ b/pgproto3/command_complete.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type CommandComplete struct { @@ -31,17 +29,11 @@ func (dst *CommandComplete) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CommandComplete) Encode(dst []byte) []byte { - dst = append(dst, 'C') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *CommandComplete) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'C') dst = append(dst, src.CommandTag...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/copy_both_response.go b/pgproto3/copy_both_response.go index 8840a89e..dbbd8e15 100644 --- a/pgproto3/copy_both_response.go +++ b/pgproto3/copy_both_response.go @@ -44,19 +44,15 @@ func (dst *CopyBothResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyBothResponse) Encode(dst []byte) []byte { - dst = append(dst, 'W') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *CopyBothResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'W') dst = append(dst, src.OverallFormat) dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/copy_both_response_test.go b/pgproto3/copy_both_response_test.go index 4437de1d..1c988f21 100644 --- a/pgproto3/copy_both_response_test.go +++ b/pgproto3/copy_both_response_test.go @@ -5,6 +5,7 @@ import ( "github.com/jackc/pgx/v5/pgproto3" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestEncodeDecode(t *testing.T) { @@ -13,6 +14,7 @@ func TestEncodeDecode(t *testing.T) { err := dstResp.Decode(srcBytes[5:]) assert.NoError(t, err, "No errors on decode") dstBytes := []byte{} - dstBytes = dstResp.Encode(dstBytes) + dstBytes, err = dstResp.Encode(dstBytes) + require.NoError(t, err) assert.EqualValues(t, srcBytes, dstBytes, "Expecting src & dest bytes to match") } diff --git a/pgproto3/copy_data.go b/pgproto3/copy_data.go index 59e3dd94..89ecdd4d 100644 --- a/pgproto3/copy_data.go +++ b/pgproto3/copy_data.go @@ -3,8 +3,6 @@ package pgproto3 import ( "encoding/hex" "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type CopyData struct { @@ -25,11 +23,10 @@ func (dst *CopyData) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyData) Encode(dst []byte) []byte { - dst = append(dst, 'd') - dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) +func (src *CopyData) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'd') dst = append(dst, src.Data...) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/copy_done.go b/pgproto3/copy_done.go index 0e13282b..040814db 100644 --- a/pgproto3/copy_done.go +++ b/pgproto3/copy_done.go @@ -24,8 +24,8 @@ func (dst *CopyDone) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyDone) Encode(dst []byte) []byte { - return append(dst, 'c', 0, 0, 0, 4) +func (src *CopyDone) Encode(dst []byte) ([]byte, error) { + return append(dst, 'c', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/copy_fail.go b/pgproto3/copy_fail.go index 0041bbb1..72a85fd0 100644 --- a/pgproto3/copy_fail.go +++ b/pgproto3/copy_fail.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type CopyFail struct { @@ -28,17 +26,11 @@ func (dst *CopyFail) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyFail) Encode(dst []byte) []byte { - dst = append(dst, 'f') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *CopyFail) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'f') dst = append(dst, src.Message...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/copy_in_response.go b/pgproto3/copy_in_response.go index 4584f7df..0a772afa 100644 --- a/pgproto3/copy_in_response.go +++ b/pgproto3/copy_in_response.go @@ -44,10 +44,8 @@ func (dst *CopyInResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyInResponse) Encode(dst []byte) []byte { - dst = append(dst, 'G') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *CopyInResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'G') dst = append(dst, src.OverallFormat) dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) @@ -55,9 +53,7 @@ func (src *CopyInResponse) Encode(dst []byte) []byte { dst = pgio.AppendUint16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/copy_out_response.go b/pgproto3/copy_out_response.go index 3175c6a4..40525da6 100644 --- a/pgproto3/copy_out_response.go +++ b/pgproto3/copy_out_response.go @@ -43,10 +43,8 @@ func (dst *CopyOutResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyOutResponse) Encode(dst []byte) []byte { - dst = append(dst, 'H') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *CopyOutResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'H') dst = append(dst, src.OverallFormat) @@ -55,9 +53,7 @@ func (src *CopyOutResponse) Encode(dst []byte) []byte { dst = pgio.AppendUint16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/data_row.go b/pgproto3/data_row.go index 4de77977..cbc76dc2 100644 --- a/pgproto3/data_row.go +++ b/pgproto3/data_row.go @@ -63,10 +63,8 @@ func (dst *DataRow) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *DataRow) Encode(dst []byte) []byte { - dst = append(dst, 'D') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *DataRow) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'D') dst = pgio.AppendUint16(dst, uint16(len(src.Values))) for _, v := range src.Values { @@ -79,9 +77,7 @@ func (src *DataRow) Encode(dst []byte) []byte { dst = append(dst, v...) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/describe.go b/pgproto3/describe.go index f131d1f4..89feff21 100644 --- a/pgproto3/describe.go +++ b/pgproto3/describe.go @@ -4,8 +4,6 @@ import ( "bytes" "encoding/json" "errors" - - "github.com/jackc/pgx/v5/internal/pgio" ) type Describe struct { @@ -37,18 +35,12 @@ func (dst *Describe) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Describe) Encode(dst []byte) []byte { - dst = append(dst, 'D') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *Describe) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'D') dst = append(dst, src.ObjectType) dst = append(dst, src.Name...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/empty_query_response.go b/pgproto3/empty_query_response.go index 2b85e744..cb6cca07 100644 --- a/pgproto3/empty_query_response.go +++ b/pgproto3/empty_query_response.go @@ -20,8 +20,8 @@ func (dst *EmptyQueryResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *EmptyQueryResponse) Encode(dst []byte) []byte { - return append(dst, 'I', 0, 0, 0, 4) +func (src *EmptyQueryResponse) Encode(dst []byte) ([]byte, error) { + return append(dst, 'I', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/error_response.go b/pgproto3/error_response.go index 45c9a981..6ef9bd06 100644 --- a/pgproto3/error_response.go +++ b/pgproto3/error_response.go @@ -2,7 +2,6 @@ package pgproto3 import ( "bytes" - "encoding/binary" "encoding/json" "strconv" ) @@ -111,119 +110,113 @@ func (dst *ErrorResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ErrorResponse) Encode(dst []byte) []byte { - return append(dst, src.marshalBinary('E')...) +func (src *ErrorResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'E') + dst = src.appendFields(dst) + return finishMessage(dst, sp) } -func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte(typeByte) - buf.Write(bigEndian.Uint32(0)) - +func (src *ErrorResponse) appendFields(dst []byte) []byte { if src.Severity != "" { - buf.WriteByte('S') - buf.WriteString(src.Severity) - buf.WriteByte(0) + dst = append(dst, 'S') + dst = append(dst, src.Severity...) + dst = append(dst, 0) } if src.SeverityUnlocalized != "" { - buf.WriteByte('V') - buf.WriteString(src.SeverityUnlocalized) - buf.WriteByte(0) + dst = append(dst, 'V') + dst = append(dst, src.SeverityUnlocalized...) + dst = append(dst, 0) } if src.Code != "" { - buf.WriteByte('C') - buf.WriteString(src.Code) - buf.WriteByte(0) + dst = append(dst, 'C') + dst = append(dst, src.Code...) + dst = append(dst, 0) } if src.Message != "" { - buf.WriteByte('M') - buf.WriteString(src.Message) - buf.WriteByte(0) + dst = append(dst, 'M') + dst = append(dst, src.Message...) + dst = append(dst, 0) } if src.Detail != "" { - buf.WriteByte('D') - buf.WriteString(src.Detail) - buf.WriteByte(0) + dst = append(dst, 'D') + dst = append(dst, src.Detail...) + dst = append(dst, 0) } if src.Hint != "" { - buf.WriteByte('H') - buf.WriteString(src.Hint) - buf.WriteByte(0) + dst = append(dst, 'H') + dst = append(dst, src.Hint...) + dst = append(dst, 0) } if src.Position != 0 { - buf.WriteByte('P') - buf.WriteString(strconv.Itoa(int(src.Position))) - buf.WriteByte(0) + dst = append(dst, 'P') + dst = append(dst, strconv.Itoa(int(src.Position))...) + dst = append(dst, 0) } if src.InternalPosition != 0 { - buf.WriteByte('p') - buf.WriteString(strconv.Itoa(int(src.InternalPosition))) - buf.WriteByte(0) + dst = append(dst, 'p') + dst = append(dst, strconv.Itoa(int(src.InternalPosition))...) + dst = append(dst, 0) } if src.InternalQuery != "" { - buf.WriteByte('q') - buf.WriteString(src.InternalQuery) - buf.WriteByte(0) + dst = append(dst, 'q') + dst = append(dst, src.InternalQuery...) + dst = append(dst, 0) } if src.Where != "" { - buf.WriteByte('W') - buf.WriteString(src.Where) - buf.WriteByte(0) + dst = append(dst, 'W') + dst = append(dst, src.Where...) + dst = append(dst, 0) } if src.SchemaName != "" { - buf.WriteByte('s') - buf.WriteString(src.SchemaName) - buf.WriteByte(0) + dst = append(dst, 's') + dst = append(dst, src.SchemaName...) + dst = append(dst, 0) } if src.TableName != "" { - buf.WriteByte('t') - buf.WriteString(src.TableName) - buf.WriteByte(0) + dst = append(dst, 't') + dst = append(dst, src.TableName...) + dst = append(dst, 0) } if src.ColumnName != "" { - buf.WriteByte('c') - buf.WriteString(src.ColumnName) - buf.WriteByte(0) + dst = append(dst, 'c') + dst = append(dst, src.ColumnName...) + dst = append(dst, 0) } if src.DataTypeName != "" { - buf.WriteByte('d') - buf.WriteString(src.DataTypeName) - buf.WriteByte(0) + dst = append(dst, 'd') + dst = append(dst, src.DataTypeName...) + dst = append(dst, 0) } if src.ConstraintName != "" { - buf.WriteByte('n') - buf.WriteString(src.ConstraintName) - buf.WriteByte(0) + dst = append(dst, 'n') + dst = append(dst, src.ConstraintName...) + dst = append(dst, 0) } if src.File != "" { - buf.WriteByte('F') - buf.WriteString(src.File) - buf.WriteByte(0) + dst = append(dst, 'F') + dst = append(dst, src.File...) + dst = append(dst, 0) } if src.Line != 0 { - buf.WriteByte('L') - buf.WriteString(strconv.Itoa(int(src.Line))) - buf.WriteByte(0) + dst = append(dst, 'L') + dst = append(dst, strconv.Itoa(int(src.Line))...) + dst = append(dst, 0) } if src.Routine != "" { - buf.WriteByte('R') - buf.WriteString(src.Routine) - buf.WriteByte(0) + dst = append(dst, 'R') + dst = append(dst, src.Routine...) + dst = append(dst, 0) } for k, v := range src.UnknownFields { - buf.WriteByte(k) - buf.WriteString(v) - buf.WriteByte(0) + dst = append(dst, k) + dst = append(dst, v...) + dst = append(dst, 0) } - buf.WriteByte(0) + dst = append(dst, 0) - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) - - return buf.Bytes() + return dst } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/example/pgfortune/server.go b/pgproto3/example/pgfortune/server.go index 14ae71f8..06a45dda 100644 --- a/pgproto3/example/pgfortune/server.go +++ b/pgproto3/example/pgfortune/server.go @@ -46,7 +46,7 @@ func (p *PgFortuneBackend) Run() error { return fmt.Errorf("error generating query response: %w", err) } - buf := (&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{ + buf := mustEncode((&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{ { Name: []byte("fortune"), TableOID: 0, @@ -56,10 +56,10 @@ func (p *PgFortuneBackend) Run() error { TypeModifier: -1, Format: 0, }, - }}).Encode(nil) - buf = (&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf) - buf = (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf) - buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf) + }}).Encode(nil)) + buf = mustEncode((&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf)) + buf = mustEncode((&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf)) + buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)) _, err = p.conn.Write(buf) if err != nil { return fmt.Errorf("error writing query response: %w", err) @@ -80,8 +80,8 @@ func (p *PgFortuneBackend) handleStartup() error { switch startupMessage.(type) { case *pgproto3.StartupMessage: - buf := (&pgproto3.AuthenticationOk{}).Encode(nil) - buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf) + buf := mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil)) + buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)) _, err = p.conn.Write(buf) if err != nil { return fmt.Errorf("error sending ready for query: %w", err) @@ -102,3 +102,10 @@ func (p *PgFortuneBackend) handleStartup() error { func (p *PgFortuneBackend) Close() error { return p.conn.Close() } + +func mustEncode(buf []byte, err error) []byte { + if err != nil { + panic(err) + } + return buf +} diff --git a/pgproto3/execute.go b/pgproto3/execute.go index a5fee7cb..31bc714d 100644 --- a/pgproto3/execute.go +++ b/pgproto3/execute.go @@ -36,19 +36,12 @@ func (dst *Execute) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Execute) Encode(dst []byte) []byte { - dst = append(dst, 'E') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *Execute) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'E') dst = append(dst, src.Portal...) dst = append(dst, 0) - dst = pgio.AppendUint32(dst, src.MaxRows) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/flush.go b/pgproto3/flush.go index 2725f689..e5dc1fbb 100644 --- a/pgproto3/flush.go +++ b/pgproto3/flush.go @@ -20,8 +20,8 @@ func (dst *Flush) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Flush) Encode(dst []byte) []byte { - return append(dst, 'H', 0, 0, 0, 4) +func (src *Flush) Encode(dst []byte) ([]byte, error) { + return append(dst, 'H', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index 60c34ef0..b41abbe1 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -18,7 +18,8 @@ type Frontend struct { // idle. Setting and unsetting tracer provides equivalent functionality to PQtrace and PQuntrace in libpq. tracer *tracer - wbuf []byte + wbuf []byte + encodeError error // Backend message flyweights authenticationOk AuthenticationOk @@ -64,16 +65,26 @@ func NewFrontend(r io.Reader, w io.Writer) *Frontend { return &Frontend{cr: cr, w: w} } -// Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is -// called. +// Send sends a message to the backend (i.e. the server). The message is buffered until Flush is called. Any error +// encountered will be returned from Flush. // // Send can work with any FrontendMessage. Some commonly used message types such as Bind have specialized send methods // such as SendBind. These methods should be preferred when the type of message is known up front (e.g. when building an // extended query protocol query) as they may be faster due to knowing the type of msg rather than it being hidden // behind an interface. func (f *Frontend) Send(msg FrontendMessage) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg) } @@ -81,6 +92,12 @@ func (f *Frontend) Send(msg FrontendMessage) { // Flush writes any pending messages to the backend (i.e. the server). func (f *Frontend) Flush() error { + if err := f.encodeError; err != nil { + f.encodeError = nil + f.wbuf = f.wbuf[:0] + return &writeError{err: err, safeToRetry: true} + } + if len(f.wbuf) == 0 { return nil } @@ -116,71 +133,141 @@ func (f *Frontend) Untrace() { f.tracer = nil } -// SendBind sends a Bind message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendBind sends a Bind message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. func (f *Frontend) SendBind(msg *Bind) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendParse sends a Parse message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendParse sends a Parse message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. func (f *Frontend) SendParse(msg *Parse) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendClose sends a Close message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendClose sends a Close message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. func (f *Frontend) SendClose(msg *Close) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceClose('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendDescribe sends a Describe message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendDescribe sends a Describe message to the backend (i.e. the server). The message is buffered until Flush is +// called. Any error encountered will be returned from Flush. func (f *Frontend) SendDescribe(msg *Describe) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendExecute sends an Execute message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendExecute sends an Execute message to the backend (i.e. the server). The message is buffered until Flush is called. +// Any error encountered will be returned from Flush. func (f *Frontend) SendExecute(msg *Execute) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendSync sends a Sync message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendSync sends a Sync message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. func (f *Frontend) SendSync(msg *Sync) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendQuery sends a Query message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendQuery sends a Query message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. func (f *Frontend) SendQuery(msg *Query) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg) } diff --git a/pgproto3/function_call.go b/pgproto3/function_call.go index 2c4f38df..0b15fce2 100644 --- a/pgproto3/function_call.go +++ b/pgproto3/function_call.go @@ -71,10 +71,8 @@ func (dst *FunctionCall) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *FunctionCall) Encode(dst []byte) []byte { - dst = append(dst, 'F') - sp := len(dst) - dst = pgio.AppendUint32(dst, 0) // Unknown length, set it at the end +func (src *FunctionCall) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'F') dst = pgio.AppendUint32(dst, src.Function) dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes))) for _, argFormatCode := range src.ArgFormatCodes { @@ -90,6 +88,5 @@ func (src *FunctionCall) Encode(dst []byte) []byte { } } dst = pgio.AppendUint16(dst, src.ResultFormatCode) - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - return dst + return finishMessage(dst, sp) } diff --git a/pgproto3/function_call_response.go b/pgproto3/function_call_response.go index 3d3606dd..1f273495 100644 --- a/pgproto3/function_call_response.go +++ b/pgproto3/function_call_response.go @@ -39,10 +39,8 @@ func (dst *FunctionCallResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *FunctionCallResponse) Encode(dst []byte) []byte { - dst = append(dst, 'V') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *FunctionCallResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'V') if src.Result == nil { dst = pgio.AppendInt32(dst, -1) @@ -51,9 +49,7 @@ func (src *FunctionCallResponse) Encode(dst []byte) []byte { dst = append(dst, src.Result...) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/function_call_test.go b/pgproto3/function_call_test.go index 8c08bb24..2a70fd30 100644 --- a/pgproto3/function_call_test.go +++ b/pgproto3/function_call_test.go @@ -4,6 +4,8 @@ import ( "encoding/binary" "reflect" "testing" + + "github.com/stretchr/testify/require" ) func TestFunctionCall_EncodeDecode(t *testing.T) { @@ -30,7 +32,8 @@ func TestFunctionCall_EncodeDecode(t *testing.T) { Arguments: tt.fields.Arguments, ResultFormatCode: tt.fields.ResultFormatCode, } - encoded := src.Encode([]byte{}) + encoded, err := src.Encode([]byte{}) + require.NoError(t, err) dst := &FunctionCall{} // Check the header msgTypeCode := encoded[0] @@ -44,7 +47,7 @@ func TestFunctionCall_EncodeDecode(t *testing.T) { t.Errorf("Incorrect message length, got = %v, wanted = %v", l, len(encoded)) } // Check decoding works as expected - err := dst.Decode(encoded[5:]) + err = dst.Decode(encoded[5:]) if err != nil { if !tt.wantErr { t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr) diff --git a/pgproto3/gss_enc_request.go b/pgproto3/gss_enc_request.go index 30ffc08d..70cb20cd 100644 --- a/pgproto3/gss_enc_request.go +++ b/pgproto3/gss_enc_request.go @@ -31,10 +31,10 @@ func (dst *GSSEncRequest) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 4 byte message length. -func (src *GSSEncRequest) Encode(dst []byte) []byte { +func (src *GSSEncRequest) Encode(dst []byte) ([]byte, error) { dst = pgio.AppendInt32(dst, 8) dst = pgio.AppendInt32(dst, gssEncReqNumber) - return dst + return dst, nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/gss_response.go b/pgproto3/gss_response.go index 64bfbd04..10d93775 100644 --- a/pgproto3/gss_response.go +++ b/pgproto3/gss_response.go @@ -2,8 +2,6 @@ package pgproto3 import ( "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type GSSResponse struct { @@ -18,11 +16,10 @@ func (g *GSSResponse) Decode(data []byte) error { return nil } -func (g *GSSResponse) Encode(dst []byte) []byte { - dst = append(dst, 'p') - dst = pgio.AppendInt32(dst, int32(4+len(g.Data))) +func (g *GSSResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') dst = append(dst, g.Data...) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/no_data.go b/pgproto3/no_data.go index d8f85d38..cbcaad40 100644 --- a/pgproto3/no_data.go +++ b/pgproto3/no_data.go @@ -20,8 +20,8 @@ func (dst *NoData) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *NoData) Encode(dst []byte) []byte { - return append(dst, 'n', 0, 0, 0, 4) +func (src *NoData) Encode(dst []byte) ([]byte, error) { + return append(dst, 'n', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/notice_response.go b/pgproto3/notice_response.go index 4ac28a79..497aba6d 100644 --- a/pgproto3/notice_response.go +++ b/pgproto3/notice_response.go @@ -12,6 +12,8 @@ func (dst *NoticeResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *NoticeResponse) Encode(dst []byte) []byte { - return append(dst, (*ErrorResponse)(src).marshalBinary('N')...) +func (src *NoticeResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'N') + dst = (*ErrorResponse)(src).appendFields(dst) + return finishMessage(dst, sp) } diff --git a/pgproto3/notification_response.go b/pgproto3/notification_response.go index 228e0dac..243b6bf7 100644 --- a/pgproto3/notification_response.go +++ b/pgproto3/notification_response.go @@ -45,20 +45,14 @@ func (dst *NotificationResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *NotificationResponse) Encode(dst []byte) []byte { - dst = append(dst, 'A') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *NotificationResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'A') dst = pgio.AppendUint32(dst, src.PID) dst = append(dst, src.Channel...) dst = append(dst, 0) dst = append(dst, src.Payload...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/parameter_description.go b/pgproto3/parameter_description.go index 374d38a3..685e04b8 100644 --- a/pgproto3/parameter_description.go +++ b/pgproto3/parameter_description.go @@ -39,19 +39,15 @@ func (dst *ParameterDescription) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ParameterDescription) Encode(dst []byte) []byte { - dst = append(dst, 't') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *ParameterDescription) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 't') dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) for _, oid := range src.ParameterOIDs { dst = pgio.AppendUint32(dst, oid) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/parameter_status.go b/pgproto3/parameter_status.go index a303e453..9ee0720b 100644 --- a/pgproto3/parameter_status.go +++ b/pgproto3/parameter_status.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type ParameterStatus struct { @@ -37,19 +35,13 @@ func (dst *ParameterStatus) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ParameterStatus) Encode(dst []byte) []byte { - dst = append(dst, 'S') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *ParameterStatus) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'S') dst = append(dst, src.Name...) dst = append(dst, 0) dst = append(dst, src.Value...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/parse.go b/pgproto3/parse.go index b53200dc..a59154cd 100644 --- a/pgproto3/parse.go +++ b/pgproto3/parse.go @@ -52,10 +52,8 @@ func (dst *Parse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Parse) Encode(dst []byte) []byte { - dst = append(dst, 'P') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *Parse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'P') dst = append(dst, src.Name...) dst = append(dst, 0) @@ -67,9 +65,7 @@ func (src *Parse) Encode(dst []byte) []byte { dst = pgio.AppendUint32(dst, oid) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/parse_complete.go b/pgproto3/parse_complete.go index 92c9498b..cff9e27d 100644 --- a/pgproto3/parse_complete.go +++ b/pgproto3/parse_complete.go @@ -20,8 +20,8 @@ func (dst *ParseComplete) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ParseComplete) Encode(dst []byte) []byte { - return append(dst, '1', 0, 0, 0, 4) +func (src *ParseComplete) Encode(dst []byte) ([]byte, error) { + return append(dst, '1', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/password_message.go b/pgproto3/password_message.go index 41f98692..d820d327 100644 --- a/pgproto3/password_message.go +++ b/pgproto3/password_message.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type PasswordMessage struct { @@ -32,14 +30,11 @@ func (dst *PasswordMessage) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *PasswordMessage) Encode(dst []byte) []byte { - dst = append(dst, 'p') - dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1)) - +func (src *PasswordMessage) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') dst = append(dst, src.Password...) dst = append(dst, 0) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/pgproto3.go b/pgproto3/pgproto3.go index 8df383c2..480abfc0 100644 --- a/pgproto3/pgproto3.go +++ b/pgproto3/pgproto3.go @@ -4,8 +4,14 @@ import ( "encoding/hex" "errors" "fmt" + + "github.com/jackc/pgx/v5/internal/pgio" ) +// maxMessageBodyLen is the maximum length of a message body in bytes. See PG_LARGE_MESSAGE_LIMIT in the PostgreSQL +// source. It is defined as (MaxAllocSize - 1). MaxAllocSize is defined as 0x3fffffff. +const maxMessageBodyLen = (0x3fffffff - 1) + // Message is the interface implemented by an object that can decode and encode // a particular PostgreSQL message. type Message interface { @@ -14,7 +20,7 @@ type Message interface { Decode(data []byte) error // Encode appends itself to dst and returns the new buffer. - Encode(dst []byte) []byte + Encode(dst []byte) ([]byte, error) } // FrontendMessage is a message sent by the frontend (i.e. the client). @@ -92,3 +98,23 @@ func getValueFromJSON(v map[string]string) ([]byte, error) { } return nil, errors.New("unknown protocol representation") } + +// beginMessage begines a new message of type t. It appends the message type and a placeholder for the message length to +// dst. It returns the new buffer and the position of the message length placeholder. +func beginMessage(dst []byte, t byte) ([]byte, int) { + dst = append(dst, t) + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + return dst, sp +} + +// finishMessage finishes a message that was started with beginMessage. It computes the message length and writes it to +// dst[sp]. If the message length is too large it returns an error. Otherwise it returns the final message buffer. +func finishMessage(dst []byte, sp int) ([]byte, error) { + messageBodyLen := len(dst[sp:]) + if messageBodyLen > maxMessageBodyLen { + return nil, errors.New("message body too large") + } + pgio.SetInt32(dst[sp:], int32(messageBodyLen)) + return dst, nil +} diff --git a/pgproto3/pgproto3_private_test.go b/pgproto3/pgproto3_private_test.go new file mode 100644 index 00000000..15da1eaf --- /dev/null +++ b/pgproto3/pgproto3_private_test.go @@ -0,0 +1,3 @@ +package pgproto3 + +const MaxMessageBodyLen = maxMessageBodyLen diff --git a/pgproto3/portal_suspended.go b/pgproto3/portal_suspended.go index 1a9e7bfb..9e2f8cbc 100644 --- a/pgproto3/portal_suspended.go +++ b/pgproto3/portal_suspended.go @@ -20,8 +20,8 @@ func (dst *PortalSuspended) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *PortalSuspended) Encode(dst []byte) []byte { - return append(dst, 's', 0, 0, 0, 4) +func (src *PortalSuspended) Encode(dst []byte) ([]byte, error) { + return append(dst, 's', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/query.go b/pgproto3/query.go index e963a0ec..aebdfde8 100644 --- a/pgproto3/query.go +++ b/pgproto3/query.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type Query struct { @@ -28,14 +26,11 @@ func (dst *Query) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Query) Encode(dst []byte) []byte { - dst = append(dst, 'Q') - dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1)) - +func (src *Query) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'Q') dst = append(dst, src.String...) dst = append(dst, 0) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/query_test.go b/pgproto3/query_test.go new file mode 100644 index 00000000..9551fc14 --- /dev/null +++ b/pgproto3/query_test.go @@ -0,0 +1,20 @@ +package pgproto3_test + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/require" +) + +func TestQueryBiggerThanMaxMessageBodyLen(t *testing.T) { + t.Parallel() + + // Maximum allowed size. 4 bytes for size and 1 byte for 0 terminated string. + _, err := (&pgproto3.Query{String: string(make([]byte, pgproto3.MaxMessageBodyLen-5))}).Encode(nil) + require.NoError(t, err) + + // 1 byte too big + _, err = (&pgproto3.Query{String: string(make([]byte, pgproto3.MaxMessageBodyLen-4))}).Encode(nil) + require.Error(t, err) +} diff --git a/pgproto3/ready_for_query.go b/pgproto3/ready_for_query.go index 67a39be3..a56af9fb 100644 --- a/pgproto3/ready_for_query.go +++ b/pgproto3/ready_for_query.go @@ -25,8 +25,8 @@ func (dst *ReadyForQuery) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ReadyForQuery) Encode(dst []byte) []byte { - return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus) +func (src *ReadyForQuery) Encode(dst []byte) ([]byte, error) { + return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/row_description.go b/pgproto3/row_description.go index 6f6f0681..c68f1d46 100644 --- a/pgproto3/row_description.go +++ b/pgproto3/row_description.go @@ -99,10 +99,8 @@ func (dst *RowDescription) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *RowDescription) Encode(dst []byte) []byte { - dst = append(dst, 'T') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *RowDescription) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'T') dst = pgio.AppendUint16(dst, uint16(len(src.Fields))) for _, fd := range src.Fields { @@ -117,9 +115,7 @@ func (src *RowDescription) Encode(dst []byte) []byte { dst = pgio.AppendInt16(dst, fd.Format) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/sasl_initial_response.go b/pgproto3/sasl_initial_response.go index eeda4691..9eb1b6a4 100644 --- a/pgproto3/sasl_initial_response.go +++ b/pgproto3/sasl_initial_response.go @@ -39,10 +39,8 @@ func (dst *SASLInitialResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *SASLInitialResponse) Encode(dst []byte) []byte { - dst = append(dst, 'p') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *SASLInitialResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') dst = append(dst, []byte(src.AuthMechanism)...) dst = append(dst, 0) @@ -50,9 +48,7 @@ func (src *SASLInitialResponse) Encode(dst []byte) []byte { dst = pgio.AppendInt32(dst, int32(len(src.Data))) dst = append(dst, src.Data...) - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/sasl_response.go b/pgproto3/sasl_response.go index 54c3d96f..1b604c25 100644 --- a/pgproto3/sasl_response.go +++ b/pgproto3/sasl_response.go @@ -3,8 +3,6 @@ package pgproto3 import ( "encoding/hex" "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type SASLResponse struct { @@ -22,13 +20,10 @@ func (dst *SASLResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *SASLResponse) Encode(dst []byte) []byte { - dst = append(dst, 'p') - dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) - +func (src *SASLResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') dst = append(dst, src.Data...) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/ssl_request.go b/pgproto3/ssl_request.go index 1b00c16b..b0fc2847 100644 --- a/pgproto3/ssl_request.go +++ b/pgproto3/ssl_request.go @@ -31,10 +31,10 @@ func (dst *SSLRequest) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 4 byte message length. -func (src *SSLRequest) Encode(dst []byte) []byte { +func (src *SSLRequest) Encode(dst []byte) ([]byte, error) { dst = pgio.AppendInt32(dst, 8) dst = pgio.AppendInt32(dst, sslRequestNumber) - return dst + return dst, nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/startup_message.go b/pgproto3/startup_message.go index 65de4a36..3af4587d 100644 --- a/pgproto3/startup_message.go +++ b/pgproto3/startup_message.go @@ -64,7 +64,7 @@ func (dst *StartupMessage) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *StartupMessage) Encode(dst []byte) []byte { +func (src *StartupMessage) Encode(dst []byte) ([]byte, error) { sp := len(dst) dst = pgio.AppendInt32(dst, -1) @@ -77,9 +77,7 @@ func (src *StartupMessage) Encode(dst []byte) []byte { } dst = append(dst, 0) - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/sync.go b/pgproto3/sync.go index 5db8e07a..ea4fc959 100644 --- a/pgproto3/sync.go +++ b/pgproto3/sync.go @@ -20,8 +20,8 @@ func (dst *Sync) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Sync) Encode(dst []byte) []byte { - return append(dst, 'S', 0, 0, 0, 4) +func (src *Sync) Encode(dst []byte) ([]byte, error) { + return append(dst, 'S', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/terminate.go b/pgproto3/terminate.go index 135191ea..35a9dc83 100644 --- a/pgproto3/terminate.go +++ b/pgproto3/terminate.go @@ -20,8 +20,8 @@ func (dst *Terminate) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Terminate) Encode(dst []byte) []byte { - return append(dst, 'X', 0, 0, 0, 4) +func (src *Terminate) Encode(dst []byte) ([]byte, error) { + return append(dst, 'X', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. From 20344dfae83e672eff73a03324398543a1d87f43 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 2 Mar 2024 11:56:44 -0600 Subject: [PATCH 26/38] Check for overflow on uint16 sizes in pgproto3 --- pgproto3/bind.go | 11 +++++++++++ pgproto3/copy_both_response.go | 4 ++++ pgproto3/copy_in_response.go | 4 ++++ pgproto3/copy_out_response.go | 4 ++++ pgproto3/data_row.go | 5 +++++ pgproto3/function_call.go | 10 ++++++++++ pgproto3/parameter_description.go | 5 +++++ pgproto3/parse.go | 5 +++++ pgproto3/row_description.go | 5 +++++ 9 files changed, 53 insertions(+) diff --git a/pgproto3/bind.go b/pgproto3/bind.go index b32cd81c..ad6ac48b 100644 --- a/pgproto3/bind.go +++ b/pgproto3/bind.go @@ -5,7 +5,9 @@ import ( "encoding/binary" "encoding/hex" "encoding/json" + "errors" "fmt" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -116,11 +118,17 @@ func (src *Bind) Encode(dst []byte) ([]byte, error) { dst = append(dst, src.PreparedStatement...) dst = append(dst, 0) + if len(src.ParameterFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many parameter format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes))) for _, fc := range src.ParameterFormatCodes { dst = pgio.AppendInt16(dst, fc) } + if len(src.Parameters) > math.MaxUint16 { + return nil, errors.New("too many parameters") + } dst = pgio.AppendUint16(dst, uint16(len(src.Parameters))) for _, p := range src.Parameters { if p == nil { @@ -132,6 +140,9 @@ func (src *Bind) Encode(dst []byte) ([]byte, error) { dst = append(dst, p...) } + if len(src.ResultFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many result format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes))) for _, fc := range src.ResultFormatCodes { dst = pgio.AppendInt16(dst, fc) diff --git a/pgproto3/copy_both_response.go b/pgproto3/copy_both_response.go index dbbd8e15..99e1afea 100644 --- a/pgproto3/copy_both_response.go +++ b/pgproto3/copy_both_response.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/json" "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -47,6 +48,9 @@ func (dst *CopyBothResponse) Decode(src []byte) error { func (src *CopyBothResponse) Encode(dst []byte) ([]byte, error) { dst, sp := beginMessage(dst, 'W') dst = append(dst, src.OverallFormat) + if len(src.ColumnFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many column format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) diff --git a/pgproto3/copy_in_response.go b/pgproto3/copy_in_response.go index 0a772afa..06cf99ce 100644 --- a/pgproto3/copy_in_response.go +++ b/pgproto3/copy_in_response.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/json" "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -48,6 +49,9 @@ func (src *CopyInResponse) Encode(dst []byte) ([]byte, error) { dst, sp := beginMessage(dst, 'G') dst = append(dst, src.OverallFormat) + if len(src.ColumnFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many column format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) diff --git a/pgproto3/copy_out_response.go b/pgproto3/copy_out_response.go index 40525da6..549e916c 100644 --- a/pgproto3/copy_out_response.go +++ b/pgproto3/copy_out_response.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/json" "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -48,6 +49,9 @@ func (src *CopyOutResponse) Encode(dst []byte) ([]byte, error) { dst = append(dst, src.OverallFormat) + if len(src.ColumnFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many column format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) diff --git a/pgproto3/data_row.go b/pgproto3/data_row.go index cbc76dc2..fdfb0f7f 100644 --- a/pgproto3/data_row.go +++ b/pgproto3/data_row.go @@ -4,6 +4,8 @@ import ( "encoding/binary" "encoding/hex" "encoding/json" + "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -66,6 +68,9 @@ func (dst *DataRow) Decode(src []byte) error { func (src *DataRow) Encode(dst []byte) ([]byte, error) { dst, sp := beginMessage(dst, 'D') + if len(src.Values) > math.MaxUint16 { + return nil, errors.New("too many values") + } dst = pgio.AppendUint16(dst, uint16(len(src.Values))) for _, v := range src.Values { if v == nil { diff --git a/pgproto3/function_call.go b/pgproto3/function_call.go index 0b15fce2..7d83579f 100644 --- a/pgproto3/function_call.go +++ b/pgproto3/function_call.go @@ -2,6 +2,8 @@ package pgproto3 import ( "encoding/binary" + "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -74,10 +76,18 @@ func (dst *FunctionCall) Decode(src []byte) error { func (src *FunctionCall) Encode(dst []byte) ([]byte, error) { dst, sp := beginMessage(dst, 'F') dst = pgio.AppendUint32(dst, src.Function) + + if len(src.ArgFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many arg format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes))) for _, argFormatCode := range src.ArgFormatCodes { dst = pgio.AppendUint16(dst, argFormatCode) } + + if len(src.Arguments) > math.MaxUint16 { + return nil, errors.New("too many arguments") + } dst = pgio.AppendUint16(dst, uint16(len(src.Arguments))) for _, argument := range src.Arguments { if argument == nil { diff --git a/pgproto3/parameter_description.go b/pgproto3/parameter_description.go index 685e04b8..1ef27b75 100644 --- a/pgproto3/parameter_description.go +++ b/pgproto3/parameter_description.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -42,6 +44,9 @@ func (dst *ParameterDescription) Decode(src []byte) error { func (src *ParameterDescription) Encode(dst []byte) ([]byte, error) { dst, sp := beginMessage(dst, 't') + if len(src.ParameterOIDs) > math.MaxUint16 { + return nil, errors.New("too many parameter oids") + } dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) for _, oid := range src.ParameterOIDs { dst = pgio.AppendUint32(dst, oid) diff --git a/pgproto3/parse.go b/pgproto3/parse.go index a59154cd..6ba3486c 100644 --- a/pgproto3/parse.go +++ b/pgproto3/parse.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -60,6 +62,9 @@ func (src *Parse) Encode(dst []byte) ([]byte, error) { dst = append(dst, src.Query...) dst = append(dst, 0) + if len(src.ParameterOIDs) > math.MaxUint16 { + return nil, errors.New("too many parameter oids") + } dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) for _, oid := range src.ParameterOIDs { dst = pgio.AppendUint32(dst, oid) diff --git a/pgproto3/row_description.go b/pgproto3/row_description.go index c68f1d46..dc2a4ddf 100644 --- a/pgproto3/row_description.go +++ b/pgproto3/row_description.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -102,6 +104,9 @@ func (dst *RowDescription) Decode(src []byte) error { func (src *RowDescription) Encode(dst []byte) ([]byte, error) { dst, sp := beginMessage(dst, 'T') + if len(src.Fields) > math.MaxUint16 { + return nil, errors.New("too many fields") + } dst = pgio.AppendUint16(dst, uint16(len(src.Fields))) for _, fd := range src.Fields { dst = append(dst, fd.Name...) From c543134753a0c5d22881c12404025724cb05ffd8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 4 Mar 2024 09:05:32 -0600 Subject: [PATCH 27/38] SQL sanitizer wraps arguments in parentheses pgx v5 was not vulnerable to CVE-2024-27289 do to how the sanitizer was being called. But the sanitizer itself still had the underlying issue. This commit ports the fix from pgx v4 to v5 to ensure that the issue does not emerge if pgx uses the sanitizer differently in the future. --- internal/sanitize/sanitize.go | 4 ++++ internal/sanitize/sanitize_test.go | 28 +++++++++++++++++++--------- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index f9091cd4..08d24fe4 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -63,6 +63,10 @@ func (q *Query) Sanitize(args ...any) (string, error) { return "", fmt.Errorf("invalid arg type: %T", arg) } argUse[argIdx] = true + + // Prevent SQL injection via Line Comment Creation + // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p + str = "(" + str + ")" default: return "", fmt.Errorf("invalid Part type: %T", part) } diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go index e2533aab..191bf1e9 100644 --- a/internal/sanitize/sanitize_test.go +++ b/internal/sanitize/sanitize_test.go @@ -132,47 +132,57 @@ func TestQuerySanitize(t *testing.T) { { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{int64(42)}, - expected: `select 42`, + expected: `select (42)`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{float64(1.23)}, - expected: `select 1.23`, + expected: `select (1.23)`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{true}, - expected: `select true`, + expected: `select (true)`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{[]byte{0, 1, 2, 3, 255}}, - expected: `select '\x00010203ff'`, + expected: `select ('\x00010203ff')`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{nil}, - expected: `select null`, + expected: `select (null)`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{"foobar"}, - expected: `select 'foobar'`, + expected: `select ('foobar')`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{"foo'bar"}, - expected: `select 'foo''bar'`, + expected: `select ('foo''bar')`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{`foo\'bar`}, - expected: `select 'foo\''bar'`, + expected: `select ('foo\''bar')`, }, { query: sanitize.Query{Parts: []sanitize.Part{"insert ", 1}}, args: []any{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)}, - expected: `insert '2020-03-01 23:59:59.999999Z'`, + expected: `insert ('2020-03-01 23:59:59.999999Z')`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}}, + args: []any{int64(-1)}, + expected: `select 1-(-1)`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}}, + args: []any{float64(-1)}, + expected: `select 1-(-1)`, }, } From da6f2c98f2664b215b40b1606551fdfcc7f3ea5c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 4 Mar 2024 09:12:06 -0600 Subject: [PATCH 28/38] Update changelog --- CHANGELOG.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4fcbc247..78de6db7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,22 @@ +# 5.5.4 (March 4, 2024) + +Fix CVE-2024-27304 + +SQL injection can occur if an attacker can cause a single query or bind message to exceed 4 GB in size. An integer +overflow in the calculated message size can cause the one large message to be sent as multiple messages under the +attacker's control. + +Thanks to Paul Gerste for reporting this issue. + +* Fix behavior of CollectRows to return empty slice if Rows are empty (Felix) +* Fix simple protocol encoding of json.RawMessage +* Fix *Pipeline.getResults should close pipeline on error +* Fix panic in TryFindUnderlyingTypeScanPlan (David Kurman) +* Fix deallocation of invalidated cached statements in a transaction +* Handle invalid sslkey file +* Fix scan float4 into sql.Scanner +* Fix pgtype.Bits not making copy of data from read buffer. This would cause the data to be corrupted by future reads. + # 5.5.3 (February 3, 2024) * Fix: prepared statement already exists From 0cc4c14e620d484d555f2955fb39cfecea89aaa3 Mon Sep 17 00:00:00 2001 From: Felix <23635466+its-felix@users.noreply.github.com> Date: Tue, 5 Mar 2024 04:07:44 +0100 Subject: [PATCH 29/38] Add test to validate CollectRows for empty Rows https://github.com/jackc/pgx/issues/1924 https://github.com/jackc/pgx/issues/1925 --- rows_test.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/rows_test.go b/rows_test.go index 31bd8c83..bb9d5015 100644 --- a/rows_test.go +++ b/rows_test.go @@ -175,6 +175,21 @@ func TestCollectRows(t *testing.T) { }) } +func TestCollectRowsEmpty(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select n from generate_series(1, 0) n`) + numbers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + require.NoError(t, err) + require.NotNil(t, numbers) + + assert.Empty(t, numbers) + }) +} + // This example uses CollectRows with a manually written collector function. In most cases RowTo, RowToAddrOf, // RowToStructByPos, RowToAddrOfStructByPos, or another generic function would be used. func ExampleCollectRows() { From 49b6aad319f125cd3016b1c00db015d6ca8772db Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 Mar 2024 12:06:58 -0600 Subject: [PATCH 30/38] Use spaces instead of parentheses for SQL sanitization This still solves the problem of negative numbers creating a line comment, but this avoids breaking edge cases such as `set foo to $1` where the substition is taking place in a location where an arbitrary expression is not allowed. https://github.com/jackc/pgx/issues/1928 --- internal/sanitize/sanitize.go | 2 +- internal/sanitize/sanitize_test.go | 22 +++++++++++----------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index 08d24fe4..df58c448 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -66,7 +66,7 @@ func (q *Query) Sanitize(args ...any) (string, error) { // Prevent SQL injection via Line Comment Creation // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p - str = "(" + str + ")" + str = " " + str + " " default: return "", fmt.Errorf("invalid Part type: %T", part) } diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go index 191bf1e9..1deff3fb 100644 --- a/internal/sanitize/sanitize_test.go +++ b/internal/sanitize/sanitize_test.go @@ -132,57 +132,57 @@ func TestQuerySanitize(t *testing.T) { { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{int64(42)}, - expected: `select (42)`, + expected: `select 42 `, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{float64(1.23)}, - expected: `select (1.23)`, + expected: `select 1.23 `, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{true}, - expected: `select (true)`, + expected: `select true `, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{[]byte{0, 1, 2, 3, 255}}, - expected: `select ('\x00010203ff')`, + expected: `select '\x00010203ff' `, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{nil}, - expected: `select (null)`, + expected: `select null `, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{"foobar"}, - expected: `select ('foobar')`, + expected: `select 'foobar' `, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{"foo'bar"}, - expected: `select ('foo''bar')`, + expected: `select 'foo''bar' `, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{`foo\'bar`}, - expected: `select ('foo\''bar')`, + expected: `select 'foo\''bar' `, }, { query: sanitize.Query{Parts: []sanitize.Part{"insert ", 1}}, args: []any{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)}, - expected: `insert ('2020-03-01 23:59:59.999999Z')`, + expected: `insert '2020-03-01 23:59:59.999999Z' `, }, { query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}}, args: []any{int64(-1)}, - expected: `select 1-(-1)`, + expected: `select 1- -1 `, }, { query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}}, args: []any{float64(-1)}, - expected: `select 1-(-1)`, + expected: `select 1- -1 `, }, } From a17f064492d5e560304aefde2784ef9253f1d0ce Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 Mar 2024 12:12:41 -0600 Subject: [PATCH 31/38] Update changelog --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 78de6db7..e538255d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +# 5.5.5 (March 9, 2024) + +Use spaces instead of parentheses for SQL sanitization. + +This still solves the problem of negative numbers creating a line comment, but this avoids breaking edge cases such as +`set foo to $1` where the substition is taking place in a location where an arbitrary expression is not allowed. + # 5.5.4 (March 4, 2024) Fix CVE-2024-27304 From 78a0a2bf410b28c08359fc8c7550c1200589521c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 Mar 2024 12:16:20 -0600 Subject: [PATCH 32/38] Fix spelling in changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e538255d..5f780fdc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,7 @@ Use spaces instead of parentheses for SQL sanitization. This still solves the problem of negative numbers creating a line comment, but this avoids breaking edge cases such as -`set foo to $1` where the substition is taking place in a location where an arbitrary expression is not allowed. +`set foo to $1` where the substitution is taking place in a location where an arbitrary expression is not allowed. # 5.5.4 (March 4, 2024) From 7fd6f2a4f5cebb35f47471b6806378cc3d862cd8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Mar 2024 09:23:13 -0500 Subject: [PATCH 33/38] Disable parallel testing on Github Actions CI Tests were failing with: Error: Process completed with exit code 143. This appears to mean that Github Actions killed the runner. See https://github.com/jackc/pgx/actions/runs/8216337993/job/22470808811 for an example. It appears Github Actions kills runners based on resource usage. Running tests one at a time reduces the resource usage and avoids the problem. Or at least that's what I presume is happening. It sure is fun debugging issues on cloud systems where you have limited visibility... :( fixes https://github.com/jackc/pgx/issues/1934 --- .github/workflows/ci.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1e494fdd..2776206c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -106,7 +106,8 @@ jobs: git diff --exit-code - name: Test - run: go test -v -race ./... + # parallel testing is disabled because somehow parallel testing causes Github Actions to kill the runner. + run: go test -parallel=1 -v -race ./... env: PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }} PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }} @@ -149,6 +150,7 @@ jobs: shell: bash - name: Test - run: go test -v -race ./... + # parallel testing is disabled because somehow parallel testing causes Github Actions to kill the runner. + run: go test -parallel=1 -v -race ./... env: PGX_TEST_DATABASE: ${{ steps.postgres.outputs.connection-uri }} From c1fce377ee9fd79df7e855f4b7e0c07a4bbf58fd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Mar 2024 09:44:23 -0500 Subject: [PATCH 34/38] Test Go 1.22 and drop Go 1.20 from testing on CI --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2776206c..a7a37d1a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: - go-version: ["1.20", "1.21"] + go-version: ["1.21", "1.22"] pg-version: [12, 13, 14, 15, 16, cockroachdb] include: - pg-version: 12 @@ -125,7 +125,7 @@ jobs: runs-on: windows-latest strategy: matrix: - go-version: ["1.20", "1.21"] + go-version: ["1.21", "1.22"] steps: - name: Setup PostgreSQL From 1b6227af11e9e84787d18c4c2bff730a4900ba09 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Mar 2024 09:52:50 -0500 Subject: [PATCH 35/38] Remove verbose flag from go test command on CI It is more often that interesting information is buried by the verbose output than the verbose output is useful. It can be reenabled later if necessary. --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a7a37d1a..47ed2448 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -107,7 +107,7 @@ jobs: - name: Test # parallel testing is disabled because somehow parallel testing causes Github Actions to kill the runner. - run: go test -parallel=1 -v -race ./... + run: go test -parallel=1 -race ./... env: PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }} PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }} @@ -151,6 +151,6 @@ jobs: - name: Test # parallel testing is disabled because somehow parallel testing causes Github Actions to kill the runner. - run: go test -parallel=1 -v -race ./... + run: go test -parallel=1 -race ./... env: PGX_TEST_DATABASE: ${{ steps.postgres.outputs.connection-uri }} From b6e55483416ce30d1476317a1d2246026164f016 Mon Sep 17 00:00:00 2001 From: Tomas Zahradnicek Date: Wed, 13 Mar 2024 18:06:39 +0100 Subject: [PATCH 36/38] StrictNamedArgs --- named_args.go | 58 +++++++++++++++++++++++++++++++++------------- named_args_test.go | 58 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 16 deletions(-) diff --git a/named_args.go b/named_args.go index 8367fc63..c88991ee 100644 --- a/named_args.go +++ b/named_args.go @@ -2,6 +2,7 @@ package pgx import ( "context" + "fmt" "strconv" "strings" "unicode/utf8" @@ -21,6 +22,34 @@ type NamedArgs map[string]any // RewriteQuery implements the QueryRewriter interface. func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) { + return rewriteQuery(na, sql, false) +} + +// StrictNamedArgs can be used in the same way as NamedArgs, but provided arguments are also checked to include all +// named arguments that the sql query uses, and no extra arguments. +type StrictNamedArgs map[string]any + +// RewriteQuery implements the QueryRewriter interface. +func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) { + return rewriteQuery(sna, sql, true) +} + +type namedArg string + +type sqlLexer struct { + src string + start int + pos int + nested int // multiline comment nesting level. + stateFn stateFn + parts []any + + nameToOrdinal map[namedArg]int +} + +type stateFn func(*sqlLexer) stateFn + +func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string, newArgs []any, err error) { l := &sqlLexer{ src: sql, stateFn: rawState, @@ -44,27 +73,24 @@ func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, ar newArgs = make([]any, len(l.nameToOrdinal)) for name, ordinal := range l.nameToOrdinal { - newArgs[ordinal-1] = na[string(name)] + var found bool + newArgs[ordinal-1], found = na[string(name)] + if isStrict && !found { + return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", name) + } + } + + if isStrict { + for name := range na { + if _, found := l.nameToOrdinal[namedArg(name)]; !found { + return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", name) + } + } } return sb.String(), newArgs, nil } -type namedArg string - -type sqlLexer struct { - src string - start int - pos int - nested int // multiline comment nesting level. - stateFn stateFn - parts []any - - nameToOrdinal map[namedArg]int -} - -type stateFn func(*sqlLexer) stateFn - func rawState(l *sqlLexer) stateFn { for { r, width := utf8.DecodeRuneInString(l.src[l.pos:]) diff --git a/named_args_test.go b/named_args_test.go index 49ac817d..8cab2f4d 100644 --- a/named_args_test.go +++ b/named_args_test.go @@ -93,6 +93,18 @@ func TestNamedArgsRewriteQuery(t *testing.T) { where id = $1;`, expectedArgs: []any{int32(42)}, }, + { + sql: "extra provided argument", + namedArgs: pgx.NamedArgs{"extra": int32(1)}, + expectedSQL: "extra provided argument", + expectedArgs: []any{}, + }, + { + sql: "@missing argument", + namedArgs: pgx.NamedArgs{}, + expectedSQL: "$1 argument", + expectedArgs: []any{nil}, + }, // test comments and quotes } { @@ -102,3 +114,49 @@ func TestNamedArgsRewriteQuery(t *testing.T) { assert.Equalf(t, tt.expectedArgs, args, "%d", i) } } + +func TestStrictNamedArgsRewriteQuery(t *testing.T) { + t.Parallel() + + for i, tt := range []struct { + sql string + namedArgs pgx.StrictNamedArgs + expectedSQL string + expectedArgs []any + isExpectedError bool + }{ + { + sql: "no arguments", + namedArgs: pgx.StrictNamedArgs{}, + expectedSQL: "no arguments", + expectedArgs: []any{}, + isExpectedError: false, + }, + { + sql: "@all @matches", + namedArgs: pgx.StrictNamedArgs{"all": int32(1), "matches": int32(2)}, + expectedSQL: "$1 $2", + expectedArgs: []any{int32(1), int32(2)}, + isExpectedError: false, + }, + { + sql: "extra provided argument", + namedArgs: pgx.StrictNamedArgs{"extra": int32(1)}, + isExpectedError: true, + }, + { + sql: "@missing argument", + namedArgs: pgx.StrictNamedArgs{}, + isExpectedError: true, + }, + } { + sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, nil) + if tt.isExpectedError { + assert.Errorf(t, err, "%d", i) + } else { + require.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expectedSQL, sql, "%d", i) + assert.Equalf(t, tt.expectedArgs, args, "%d", i) + } + } +} From 221ad1b84c5dd91e329bb284e604b0810612cdc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20P=C3=A9rez-Aradros=20Herce?= Date: Mon, 18 Mar 2024 17:29:04 +0100 Subject: [PATCH 37/38] Add support for macaddr8 type Postgres also has a `macaddr8` type, this PR adds support for it, using the same codec as `macaddr` --- pgtype/macaddr_test.go | 5 +++++ pgtype/pgtype.go | 1 + pgtype/pgtype_default.go | 1 + 3 files changed, 7 insertions(+) diff --git a/pgtype/macaddr_test.go b/pgtype/macaddr_test.go index 84cf970f..03dd1db5 100644 --- a/pgtype/macaddr_test.go +++ b/pgtype/macaddr_test.go @@ -46,6 +46,11 @@ func TestMacaddrCodec(t *testing.T) { new(string), isExpectedEq("01:23:45:67:89:ab"), }, + { + mustParseMacaddr(t, "01:23:45:67:89:ab:01:08"), + new(string), + isExpectedEq("01:23:45:67:89:ab:01:08"), + }, {nil, new(*net.HardwareAddr), isExpectedEq((*net.HardwareAddr)(nil))}, }) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 534ef6d1..d23ebc6c 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -41,6 +41,7 @@ const ( CircleOID = 718 CircleArrayOID = 719 UnknownOID = 705 + Macaddr8OID = 774 MacaddrOID = 829 InetOID = 869 BoolArrayOID = 1000 diff --git a/pgtype/pgtype_default.go b/pgtype/pgtype_default.go index c21ac081..d56d95bd 100644 --- a/pgtype/pgtype_default.go +++ b/pgtype/pgtype_default.go @@ -70,6 +70,7 @@ func initDefaultMap() { defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}}) defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) + defaultMap.RegisterType(&Type{Name: "macaddr8", OID: Macaddr8OID, Codec: MacaddrCodec{}}) defaultMap.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) defaultMap.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}}) defaultMap.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}}) From 78b22c3d2f2aa4029e7b3a8b3062fc656fb34c2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20P=C3=A9rez-Aradros=20Herce?= Date: Mon, 18 Mar 2024 22:20:55 +0100 Subject: [PATCH 38/38] fix tests --- pgtype/macaddr_test.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/pgtype/macaddr_test.go b/pgtype/macaddr_test.go index 03dd1db5..58149c87 100644 --- a/pgtype/macaddr_test.go +++ b/pgtype/macaddr_test.go @@ -46,6 +46,20 @@ func TestMacaddrCodec(t *testing.T) { new(string), isExpectedEq("01:23:45:67:89:ab"), }, + {nil, new(*net.HardwareAddr), isExpectedEq((*net.HardwareAddr)(nil))}, + }) + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "macaddr8", []pgxtest.ValueRoundTripTest{ + { + mustParseMacaddr(t, "01:23:45:67:89:ab:01:08"), + new(net.HardwareAddr), + isExpectedEqHardwareAddr(mustParseMacaddr(t, "01:23:45:67:89:ab:01:08")), + }, + { + "01:23:45:67:89:ab:01:08", + new(net.HardwareAddr), + isExpectedEqHardwareAddr(mustParseMacaddr(t, "01:23:45:67:89:ab:01:08")), + }, { mustParseMacaddr(t, "01:23:45:67:89:ab:01:08"), new(string),