Merge remote-tracking branch 'upstream/master'
This commit is contained in:
@@ -9,14 +9,11 @@ on:
|
||||
jobs:
|
||||
test:
|
||||
name: Test
|
||||
# Note: The TLS tests are rather finicky. It seems that openssl 3 encrypts certificates differently than older
|
||||
# openssl and it does it in a way Go and/or pgx ssl handling code can't handle. So stick with Ubuntu 20.04 until
|
||||
# that is figured out.
|
||||
runs-on: ubuntu-20.04
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
go-version: ["1.20", "1.21"]
|
||||
go-version: ["1.21", "1.22"]
|
||||
pg-version: [12, 13, 14, 15, 16, cockroachdb]
|
||||
include:
|
||||
- pg-version: 12
|
||||
@@ -77,7 +74,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go ${{ matrix.go-version }}
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
|
||||
@@ -109,7 +106,8 @@ jobs:
|
||||
git diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: go test -v -race ./...
|
||||
# parallel testing is disabled because somehow parallel testing causes Github Actions to kill the runner.
|
||||
run: go test -parallel=1 -race ./...
|
||||
env:
|
||||
PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }}
|
||||
PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }}
|
||||
@@ -127,7 +125,7 @@ jobs:
|
||||
runs-on: windows-latest
|
||||
strategy:
|
||||
matrix:
|
||||
go-version: ["1.20", "1.21"]
|
||||
go-version: ["1.21", "1.22"]
|
||||
|
||||
steps:
|
||||
- name: Setup PostgreSQL
|
||||
@@ -140,7 +138,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go ${{ matrix.go-version }}
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
|
||||
@@ -152,6 +150,7 @@ jobs:
|
||||
shell: bash
|
||||
|
||||
- name: Test
|
||||
run: go test -v -race ./...
|
||||
# parallel testing is disabled because somehow parallel testing causes Github Actions to kill the runner.
|
||||
run: go test -parallel=1 -race ./...
|
||||
env:
|
||||
PGX_TEST_DATABASE: ${{ steps.postgres.outputs.connection-uri }}
|
||||
|
||||
@@ -1,3 +1,39 @@
|
||||
# 5.5.5 (March 9, 2024)
|
||||
|
||||
Use spaces instead of parentheses for SQL sanitization.
|
||||
|
||||
This still solves the problem of negative numbers creating a line comment, but this avoids breaking edge cases such as
|
||||
`set foo to $1` where the substitution is taking place in a location where an arbitrary expression is not allowed.
|
||||
|
||||
# 5.5.4 (March 4, 2024)
|
||||
|
||||
Fix CVE-2024-27304
|
||||
|
||||
SQL injection can occur if an attacker can cause a single query or bind message to exceed 4 GB in size. An integer
|
||||
overflow in the calculated message size can cause the one large message to be sent as multiple messages under the
|
||||
attacker's control.
|
||||
|
||||
Thanks to Paul Gerste for reporting this issue.
|
||||
|
||||
* Fix behavior of CollectRows to return empty slice if Rows are empty (Felix)
|
||||
* Fix simple protocol encoding of json.RawMessage
|
||||
* Fix *Pipeline.getResults should close pipeline on error
|
||||
* Fix panic in TryFindUnderlyingTypeScanPlan (David Kurman)
|
||||
* Fix deallocation of invalidated cached statements in a transaction
|
||||
* Handle invalid sslkey file
|
||||
* Fix scan float4 into sql.Scanner
|
||||
* Fix pgtype.Bits not making copy of data from read buffer. This would cause the data to be corrupted by future reads.
|
||||
|
||||
# 5.5.3 (February 3, 2024)
|
||||
|
||||
* Fix: prepared statement already exists
|
||||
* Improve CopyFrom auto-conversion of text-ish values
|
||||
* Add ltree type support (Florent Viel)
|
||||
* Make some properties of Batch and QueuedQuery public (Pavlo Golub)
|
||||
* Add AppendRows function (Edoardo Spadolini)
|
||||
* Optimize convert UUID [16]byte to string (Kirill Malikov)
|
||||
* Fix: LargeObject Read and Write of more than ~1GB at a time (Mitar)
|
||||
|
||||
# 5.5.2 (January 13, 2024)
|
||||
|
||||
* Allow NamedArgs to start with underscore
|
||||
|
||||
+2
-16
@@ -79,20 +79,11 @@ echo "listen_addresses = '127.0.0.1'" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql
|
||||
echo "port = $PGPORT" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
|
||||
cat testsetup/postgresql_ssl.conf >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
|
||||
cp testsetup/pg_hba.conf .testdb/$POSTGRESQL_DATA_DIR/pg_hba.conf
|
||||
cp testsetup/ca.cnf .testdb
|
||||
cp testsetup/localhost.cnf .testdb
|
||||
cp testsetup/pgx_sslcert.cnf .testdb
|
||||
|
||||
cd .testdb
|
||||
|
||||
# Generate a CA public / private key pair.
|
||||
openssl genrsa -out ca.key 4096
|
||||
openssl req -x509 -config ca.cnf -new -nodes -key ca.key -sha256 -days 365 -subj '/O=pgx-test-root' -out ca.pem
|
||||
|
||||
# Generate the certificate for localhost (the server).
|
||||
openssl genrsa -out localhost.key 2048
|
||||
openssl req -new -config localhost.cnf -key localhost.key -out localhost.csr
|
||||
openssl x509 -req -in localhost.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out localhost.crt -days 364 -sha256 -extfile localhost.cnf -extensions v3_req
|
||||
# Generate CA, server, and encrypted client certificates.
|
||||
go run ../testsetup/generate_certs.go
|
||||
|
||||
# Copy certificates to server directory and set permissions.
|
||||
cp ca.pem $POSTGRESQL_DATA_DIR/root.crt
|
||||
@@ -100,11 +91,6 @@ cp localhost.key $POSTGRESQL_DATA_DIR/server.key
|
||||
chmod 600 $POSTGRESQL_DATA_DIR/server.key
|
||||
cp localhost.crt $POSTGRESQL_DATA_DIR/server.crt
|
||||
|
||||
# Generate the certificate for client authentication.
|
||||
openssl genrsa -des3 -out pgx_sslcert.key -passout pass:certpw 2048
|
||||
openssl req -new -config pgx_sslcert.cnf -key pgx_sslcert.key -passin pass:certpw -out pgx_sslcert.csr
|
||||
openssl x509 -req -in pgx_sslcert.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out pgx_sslcert.crt -days 363 -sha256 -extfile pgx_sslcert.cnf -extensions v3_req
|
||||
|
||||
cd ..
|
||||
```
|
||||
|
||||
|
||||
@@ -120,6 +120,7 @@ pgerrcode contains constants for the PostgreSQL error codes.
|
||||
|
||||
- [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid)
|
||||
- [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal)
|
||||
- [github.com/twpayne/pgx-geos](https://github.com/twpayne/pgx-geos) ([PostGIS](https://postgis.net/) and [GEOS](https://libgeos.org/) via [go-geos](https://github.com/twpayne/go-geos))
|
||||
- [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid)
|
||||
|
||||
## Adapters for 3rd Party Tracers
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
|
||||
// QueuedQuery is a query that has been queued for execution via a Batch.
|
||||
type QueuedQuery struct {
|
||||
query string
|
||||
arguments []any
|
||||
SQL string
|
||||
Arguments []any
|
||||
fn batchItemFunc
|
||||
sd *pgconn.StatementDescription
|
||||
}
|
||||
@@ -57,7 +57,7 @@ func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
|
||||
// Batch queries are a way of bundling multiple queries together to avoid
|
||||
// unnecessary network round trips. A Batch must only be sent once.
|
||||
type Batch struct {
|
||||
queuedQueries []*QueuedQuery
|
||||
QueuedQueries []*QueuedQuery
|
||||
}
|
||||
|
||||
// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement.
|
||||
@@ -65,16 +65,16 @@ type Batch struct {
|
||||
// connection's DefaultQueryExecMode.
|
||||
func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery {
|
||||
qq := &QueuedQuery{
|
||||
query: query,
|
||||
arguments: arguments,
|
||||
SQL: query,
|
||||
Arguments: arguments,
|
||||
}
|
||||
b.queuedQueries = append(b.queuedQueries, qq)
|
||||
b.QueuedQueries = append(b.QueuedQueries, qq)
|
||||
return qq
|
||||
}
|
||||
|
||||
// Len returns number of queries that have been queued so far.
|
||||
func (b *Batch) Len() int {
|
||||
return len(b.queuedQueries)
|
||||
return len(b.QueuedQueries)
|
||||
}
|
||||
|
||||
type BatchResults interface {
|
||||
@@ -227,9 +227,9 @@ func (br *batchResults) Close() error {
|
||||
}
|
||||
|
||||
// Read and run fn for all remaining items
|
||||
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
||||
if br.b.queuedQueries[br.qqIdx].fn != nil {
|
||||
err := br.b.queuedQueries[br.qqIdx].fn(br)
|
||||
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
|
||||
if br.b.QueuedQueries[br.qqIdx].fn != nil {
|
||||
err := br.b.QueuedQueries[br.qqIdx].fn(br)
|
||||
if err != nil {
|
||||
br.err = err
|
||||
}
|
||||
@@ -253,10 +253,10 @@ func (br *batchResults) earlyError() error {
|
||||
}
|
||||
|
||||
func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
|
||||
if br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
||||
bi := br.b.queuedQueries[br.qqIdx]
|
||||
query = bi.query
|
||||
args = bi.arguments
|
||||
if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
|
||||
bi := br.b.QueuedQueries[br.qqIdx]
|
||||
query = bi.SQL
|
||||
args = bi.Arguments
|
||||
ok = true
|
||||
br.qqIdx++
|
||||
}
|
||||
@@ -396,9 +396,9 @@ func (br *pipelineBatchResults) Close() error {
|
||||
}
|
||||
|
||||
// Read and run fn for all remaining items
|
||||
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
||||
if br.b.queuedQueries[br.qqIdx].fn != nil {
|
||||
err := br.b.queuedQueries[br.qqIdx].fn(br)
|
||||
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
|
||||
if br.b.QueuedQueries[br.qqIdx].fn != nil {
|
||||
err := br.b.QueuedQueries[br.qqIdx].fn(br)
|
||||
if err != nil {
|
||||
br.err = err
|
||||
}
|
||||
@@ -422,10 +422,10 @@ func (br *pipelineBatchResults) earlyError() error {
|
||||
}
|
||||
|
||||
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
|
||||
if br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
||||
bi := br.b.queuedQueries[br.qqIdx]
|
||||
query = bi.query
|
||||
args = bi.arguments
|
||||
if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
|
||||
bi := br.b.QueuedQueries[br.qqIdx]
|
||||
query = bi.SQL
|
||||
args = bi.Arguments
|
||||
ok = true
|
||||
br.qqIdx++
|
||||
}
|
||||
|
||||
+2
-13
@@ -16,14 +16,8 @@ then
|
||||
|
||||
cd testsetup
|
||||
|
||||
# Generate a CA public / private key pair.
|
||||
openssl genrsa -out ca.key 4096
|
||||
openssl req -x509 -config ca.cnf -new -nodes -key ca.key -sha256 -days 365 -subj '/O=pgx-test-root' -out ca.pem
|
||||
|
||||
# Generate the certificate for localhost (the server).
|
||||
openssl genrsa -out localhost.key 2048
|
||||
openssl req -new -config localhost.cnf -key localhost.key -out localhost.csr
|
||||
openssl x509 -req -in localhost.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out localhost.crt -days 364 -sha256 -extfile localhost.cnf -extensions v3_req
|
||||
# Generate CA, server, and encrypted client certificates.
|
||||
go run generate_certs.go
|
||||
|
||||
# Copy certificates to server directory and set permissions.
|
||||
sudo cp ca.pem /var/lib/postgresql/$PGVERSION/main/root.crt
|
||||
@@ -34,11 +28,6 @@ then
|
||||
sudo cp localhost.crt /var/lib/postgresql/$PGVERSION/main/server.crt
|
||||
sudo chown postgres:postgres /var/lib/postgresql/$PGVERSION/main/server.crt
|
||||
|
||||
# Generate the certificate for client authentication.
|
||||
openssl genrsa -des3 -out pgx_sslcert.key -passout pass:certpw 2048
|
||||
openssl req -new -config pgx_sslcert.cnf -key pgx_sslcert.key -passin pass:certpw -out pgx_sslcert.csr
|
||||
openssl x509 -req -in pgx_sslcert.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out pgx_sslcert.crt -days 363 -sha256 -extfile pgx_sslcert.cnf -extensions v3_req
|
||||
|
||||
cp ca.pem /tmp
|
||||
cp pgx_sslcert.key /tmp
|
||||
cp pgx_sslcert.crt /tmp
|
||||
|
||||
@@ -903,10 +903,10 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
|
||||
for _, bi := range b.queuedQueries {
|
||||
for _, bi := range b.QueuedQueries {
|
||||
var queryRewriter QueryRewriter
|
||||
sql := bi.query
|
||||
arguments := bi.arguments
|
||||
sql := bi.SQL
|
||||
arguments := bi.Arguments
|
||||
|
||||
optionLoop:
|
||||
for len(arguments) > 0 {
|
||||
@@ -928,8 +928,8 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
|
||||
}
|
||||
}
|
||||
|
||||
bi.query = sql
|
||||
bi.arguments = arguments
|
||||
bi.SQL = sql
|
||||
bi.Arguments = arguments
|
||||
}
|
||||
|
||||
// TODO: changing mode per batch? Update Batch.Queue function comment when implemented
|
||||
@@ -939,8 +939,8 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
|
||||
}
|
||||
|
||||
// All other modes use extended protocol and thus can use prepared statements.
|
||||
for _, bi := range b.queuedQueries {
|
||||
if sd, ok := c.preparedStatements[bi.query]; ok {
|
||||
for _, bi := range b.QueuedQueries {
|
||||
if sd, ok := c.preparedStatements[bi.SQL]; ok {
|
||||
bi.sd = sd
|
||||
}
|
||||
}
|
||||
@@ -961,11 +961,11 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
|
||||
|
||||
func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults {
|
||||
var sb strings.Builder
|
||||
for i, bi := range b.queuedQueries {
|
||||
for i, bi := range b.QueuedQueries {
|
||||
if i > 0 {
|
||||
sb.WriteByte(';')
|
||||
}
|
||||
sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...)
|
||||
sql, err := c.sanitizeForSimpleQuery(bi.SQL, bi.Arguments...)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
@@ -984,21 +984,21 @@ func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batc
|
||||
func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults {
|
||||
batch := &pgconn.Batch{}
|
||||
|
||||
for _, bi := range b.queuedQueries {
|
||||
for _, bi := range b.QueuedQueries {
|
||||
sd := bi.sd
|
||||
if sd != nil {
|
||||
err := c.eqb.Build(c.typeMap, sd, bi.arguments)
|
||||
err := c.eqb.Build(c.typeMap, sd, bi.Arguments)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
|
||||
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||
} else {
|
||||
err := c.eqb.Build(c.typeMap, nil, bi.arguments)
|
||||
err := c.eqb.Build(c.typeMap, nil, bi.Arguments)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||
batch.ExecParams(bi.SQL, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1023,18 +1023,18 @@ func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batc
|
||||
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||
distinctNewQueriesIdxMap := make(map[string]int)
|
||||
|
||||
for _, bi := range b.queuedQueries {
|
||||
for _, bi := range b.QueuedQueries {
|
||||
if bi.sd == nil {
|
||||
sd := c.statementCache.Get(bi.query)
|
||||
sd := c.statementCache.Get(bi.SQL)
|
||||
if sd != nil {
|
||||
bi.sd = sd
|
||||
} else {
|
||||
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
|
||||
if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
|
||||
bi.sd = distinctNewQueries[idx]
|
||||
} else {
|
||||
sd = &pgconn.StatementDescription{
|
||||
Name: stmtcache.StatementName(bi.query),
|
||||
SQL: bi.query,
|
||||
Name: stmtcache.StatementName(bi.SQL),
|
||||
SQL: bi.SQL,
|
||||
}
|
||||
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
||||
distinctNewQueries = append(distinctNewQueries, sd)
|
||||
@@ -1055,17 +1055,17 @@ func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch
|
||||
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||
distinctNewQueriesIdxMap := make(map[string]int)
|
||||
|
||||
for _, bi := range b.queuedQueries {
|
||||
for _, bi := range b.QueuedQueries {
|
||||
if bi.sd == nil {
|
||||
sd := c.descriptionCache.Get(bi.query)
|
||||
sd := c.descriptionCache.Get(bi.SQL)
|
||||
if sd != nil {
|
||||
bi.sd = sd
|
||||
} else {
|
||||
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
|
||||
if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
|
||||
bi.sd = distinctNewQueries[idx]
|
||||
} else {
|
||||
sd = &pgconn.StatementDescription{
|
||||
SQL: bi.query,
|
||||
SQL: bi.SQL,
|
||||
}
|
||||
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
||||
distinctNewQueries = append(distinctNewQueries, sd)
|
||||
@@ -1082,13 +1082,13 @@ func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch)
|
||||
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||
distinctNewQueriesIdxMap := make(map[string]int)
|
||||
|
||||
for _, bi := range b.queuedQueries {
|
||||
for _, bi := range b.QueuedQueries {
|
||||
if bi.sd == nil {
|
||||
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
|
||||
if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
|
||||
bi.sd = distinctNewQueries[idx]
|
||||
} else {
|
||||
sd := &pgconn.StatementDescription{
|
||||
SQL: bi.query,
|
||||
SQL: bi.SQL,
|
||||
}
|
||||
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
||||
distinctNewQueries = append(distinctNewQueries, sd)
|
||||
@@ -1154,11 +1154,11 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
|
||||
}
|
||||
|
||||
// Queue the queries.
|
||||
for _, bi := range b.queuedQueries {
|
||||
err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments)
|
||||
for _, bi := range b.QueuedQueries {
|
||||
err := c.eqb.Build(c.typeMap, bi.sd, bi.Arguments)
|
||||
if err != nil {
|
||||
// we wrap the error so we the user can understand which query failed inside the batch
|
||||
err = fmt.Errorf("error building query %s: %w", bi.query, err)
|
||||
err = fmt.Errorf("error building query %s: %w", bi.SQL, err)
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
|
||||
}
|
||||
|
||||
@@ -1203,7 +1203,15 @@ func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) {
|
||||
return sanitize.SanitizeSQL(sql, valueArgs...)
|
||||
}
|
||||
|
||||
// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration.
|
||||
// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration. typeName must be
|
||||
// the name of a type where the underlying type(s) is already understood by pgx. It is for derived types. In particular,
|
||||
// typeName must be one of the following:
|
||||
// - An array type name of a type that is already registered. e.g. "_foo" when "foo" is registered.
|
||||
// - A composite type name where all field types are already registered.
|
||||
// - A domain type name where the base type is already registered.
|
||||
// - An enum type name.
|
||||
// - A range type name where the element type is already registered.
|
||||
// - A multirange type name where the element type is already registered.
|
||||
func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, error) {
|
||||
var oid uint32
|
||||
|
||||
@@ -1346,17 +1354,17 @@ order by attnum`,
|
||||
}
|
||||
|
||||
func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error {
|
||||
if c.pgConn.TxStatus() != 'I' {
|
||||
if txStatus := c.pgConn.TxStatus(); txStatus != 'I' && txStatus != 'T' {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.descriptionCache != nil {
|
||||
c.descriptionCache.HandleInvalidated()
|
||||
c.descriptionCache.RemoveInvalidated()
|
||||
}
|
||||
|
||||
var invalidatedStatements []*pgconn.StatementDescription
|
||||
if c.statementCache != nil {
|
||||
invalidatedStatements = c.statementCache.HandleInvalidated()
|
||||
invalidatedStatements = c.statementCache.GetInvalidated()
|
||||
}
|
||||
|
||||
if len(invalidatedStatements) == 0 {
|
||||
@@ -1368,7 +1376,6 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error
|
||||
|
||||
for _, sd := range invalidatedStatements {
|
||||
pipeline.SendDeallocate(sd.Name)
|
||||
delete(c.preparedStatements, sd.Name)
|
||||
}
|
||||
|
||||
err := pipeline.Sync()
|
||||
@@ -1381,5 +1388,10 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error
|
||||
return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
|
||||
}
|
||||
|
||||
c.statementCache.RemoveInvalidated()
|
||||
for _, sd := range invalidatedStatements {
|
||||
delete(c.preparedStatements, sd.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1338,3 +1338,73 @@ func TestRawValuesUnderlyingMemoryReused(t *testing.T) {
|
||||
t.Fatal("expected buffer from RawValues to be overwritten by subsequent queries but it was not")
|
||||
})
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/1847
|
||||
func TestConnDeallocateInvalidatedCachedStatementsWhenCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
pgxtest.SkipCockroachDB(t, conn, "CockroachDB returns decimal instead of integer for integer division")
|
||||
|
||||
var n int32
|
||||
err := conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, n)
|
||||
|
||||
// Divide by zero causes an error. baseRows.Close() calls Invalidate on the statement cache whenever an error was
|
||||
// encountered by the query. Use this to purposely invalidate the query. If we had access to private fields of conn
|
||||
// we could call conn.statementCache.InvalidateAll() instead.
|
||||
err = conn.QueryRow(ctx, "select 1 / $1::int", 0).Scan(&n)
|
||||
require.Error(t, err)
|
||||
|
||||
ctx2, cancel2 := context.WithCancel(ctx)
|
||||
cancel2()
|
||||
err = conn.QueryRow(ctx2, "select 1 / $1::int", 1).Scan(&n)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
|
||||
err = conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, n)
|
||||
})
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/1847
|
||||
func TestConnDeallocateInvalidatedCachedStatementsInTransactionWithBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
connString := os.Getenv("PGX_TEST_DATABASE")
|
||||
config := mustParseConfig(t, connString)
|
||||
config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement
|
||||
config.StatementCacheCapacity = 2
|
||||
|
||||
conn, err := pgx.ConnectConfig(ctx, config)
|
||||
require.NoError(t, err)
|
||||
|
||||
tx, err := conn.Begin(ctx)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
_, err = tx.Exec(ctx, "select $1::int + 1", 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = tx.Exec(ctx, "select $1::int + 2", 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// This should invalidate the first cached statement.
|
||||
_, err = tx.Exec(ctx, "select $1::int + 3", 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("select $1::int + 1", 1)
|
||||
err = tx.SendBatch(ctx, batch).Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tx.Rollback(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
@@ -803,6 +803,40 @@ func TestConnCopyFromAutomaticStringConversion(t *testing.T) {
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/discussions/1891
|
||||
func TestConnCopyFromAutomaticStringConversionArray(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a numeric[]
|
||||
)`)
|
||||
|
||||
inputRows := [][]interface{}{
|
||||
{[]string{"42"}},
|
||||
{[]string{"7"}},
|
||||
{[]string{"8", "9"}},
|
||||
{[][]string{{"10", "11"}, {"12", "13"}}},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows))
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, len(inputRows), copyCount)
|
||||
|
||||
// Test reads as int64 and flattened array for simplicity.
|
||||
rows, _ := conn.Query(ctx, "select * from foo")
|
||||
nums, err := pgx.CollectRows(rows, pgx.RowTo[[]int64])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, [][]int64{{42}, {7}, {8, 9}, {10, 11, 12, 13}}, nums)
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestCopyFromFunc(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -3,9 +3,9 @@ module github.com/andoma-go/pgx/v5
|
||||
go 1.19
|
||||
|
||||
require (
|
||||
github.com/andoma-go/pgpassfile v1.0.0
|
||||
github.com/andoma-go/pgpassfile v0.0.0-20240115130830-7bdd00f68544
|
||||
github.com/andoma-go/pgservicefile v0.0.0-20240115131304-4a01ebf23c42
|
||||
github.com/andoma-go/puddle/v2 v2.2.1
|
||||
github.com/andoma-go/puddle/v2 v2.0.0-20240328142435-357666cb6fa1
|
||||
github.com/stretchr/testify v1.8.1
|
||||
golang.org/x/crypto v0.17.0
|
||||
golang.org/x/text v0.14.0
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
github.com/andoma-go/pgpassfile v1.0.0 h1:IJZAs6b/3pmEnq0kAvBWh2qEPsQOleHVIzMzj8WwT4w=
|
||||
github.com/andoma-go/pgpassfile v1.0.0/go.mod h1:JWSeNzz3oUhysdQgq1OL4PyV3R4QW/KyPvqqEykyN88=
|
||||
github.com/andoma-go/pgpassfile v0.0.0-20240115130830-7bdd00f68544 h1:zw0WuRyP2Awzl63MI2VwMXSM/CsNqwygHg/CbySE1ls=
|
||||
github.com/andoma-go/pgpassfile v0.0.0-20240115130830-7bdd00f68544/go.mod h1:JWSeNzz3oUhysdQgq1OL4PyV3R4QW/KyPvqqEykyN88=
|
||||
github.com/andoma-go/pgservicefile v0.0.0-20240115131304-4a01ebf23c42 h1:TpYPPFFHiqFDM0luTfDiHBdGSgYU+uloD+FaA87BBRk=
|
||||
github.com/andoma-go/pgservicefile v0.0.0-20240115131304-4a01ebf23c42/go.mod h1:iRoNsjH6Wp9dCo0oiT1geVOjYusx6RUIdzCJNktFso0=
|
||||
github.com/andoma-go/puddle/v2 v2.2.1 h1:cobxhnZmYsynXC9k8xcJd97ytlCa/Pe5kgj69pgncrE=
|
||||
github.com/andoma-go/puddle/v2 v2.2.1/go.mod h1:iWHUHOdNa1/WJ6MyJAZ5qeTI/sJMbjVK/Gw4JLjh4Dw=
|
||||
github.com/andoma-go/puddle/v2 v2.0.0-20240328142435-357666cb6fa1 h1:3/6Uu7EWnHeHAwZ9tfytqJy+1x8LTtYrsWGczhMJ4uc=
|
||||
github.com/andoma-go/puddle/v2 v2.0.0-20240328142435-357666cb6fa1/go.mod h1:iWHUHOdNa1/WJ6MyJAZ5qeTI/sJMbjVK/Gw4JLjh4Dw=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
|
||||
@@ -63,6 +63,10 @@ func (q *Query) Sanitize(args ...any) (string, error) {
|
||||
return "", fmt.Errorf("invalid arg type: %T", arg)
|
||||
}
|
||||
argUse[argIdx] = true
|
||||
|
||||
// Prevent SQL injection via Line Comment Creation
|
||||
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
|
||||
str = " " + str + " "
|
||||
default:
|
||||
return "", fmt.Errorf("invalid Part type: %T", part)
|
||||
}
|
||||
|
||||
@@ -132,47 +132,57 @@ func TestQuerySanitize(t *testing.T) {
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []any{int64(42)},
|
||||
expected: `select 42`,
|
||||
expected: `select 42 `,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []any{float64(1.23)},
|
||||
expected: `select 1.23`,
|
||||
expected: `select 1.23 `,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []any{true},
|
||||
expected: `select true`,
|
||||
expected: `select true `,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []any{[]byte{0, 1, 2, 3, 255}},
|
||||
expected: `select '\x00010203ff'`,
|
||||
expected: `select '\x00010203ff' `,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []any{nil},
|
||||
expected: `select null`,
|
||||
expected: `select null `,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []any{"foobar"},
|
||||
expected: `select 'foobar'`,
|
||||
expected: `select 'foobar' `,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []any{"foo'bar"},
|
||||
expected: `select 'foo''bar'`,
|
||||
expected: `select 'foo''bar' `,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []any{`foo\'bar`},
|
||||
expected: `select 'foo\''bar'`,
|
||||
expected: `select 'foo\''bar' `,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"insert ", 1}},
|
||||
args: []any{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)},
|
||||
expected: `insert '2020-03-01 23:59:59.999999Z'`,
|
||||
expected: `insert '2020-03-01 23:59:59.999999Z' `,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}},
|
||||
args: []any{int64(-1)},
|
||||
expected: `select 1- -1 `,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}},
|
||||
args: []any{float64(-1)},
|
||||
expected: `select 1- -1 `,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -81,12 +81,16 @@ func (c *LRUCache) InvalidateAll() {
|
||||
c.l = list.New()
|
||||
}
|
||||
|
||||
// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated.
|
||||
// Typically, the caller will then deallocate them.
|
||||
func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription {
|
||||
invalidStmts := c.invalidStmts
|
||||
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
|
||||
func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription {
|
||||
return c.invalidStmts
|
||||
}
|
||||
|
||||
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
|
||||
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
|
||||
// never seen by the call to GetInvalidated.
|
||||
func (c *LRUCache) RemoveInvalidated() {
|
||||
c.invalidStmts = nil
|
||||
return invalidStmts
|
||||
}
|
||||
|
||||
// Len returns the number of cached prepared statement descriptions.
|
||||
|
||||
@@ -29,8 +29,13 @@ type Cache interface {
|
||||
// InvalidateAll invalidates all statement descriptions.
|
||||
InvalidateAll()
|
||||
|
||||
// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated.
|
||||
HandleInvalidated() []*pgconn.StatementDescription
|
||||
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
|
||||
GetInvalidated() []*pgconn.StatementDescription
|
||||
|
||||
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
|
||||
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
|
||||
// never seen by the call to GetInvalidated.
|
||||
RemoveInvalidated()
|
||||
|
||||
// Len returns the number of cached prepared statement descriptions.
|
||||
Len() int
|
||||
|
||||
@@ -54,10 +54,16 @@ func (c *UnlimitedCache) InvalidateAll() {
|
||||
c.m = make(map[string]*pgconn.StatementDescription)
|
||||
}
|
||||
|
||||
func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription {
|
||||
invalidStmts := c.invalidStmts
|
||||
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
|
||||
func (c *UnlimitedCache) GetInvalidated() []*pgconn.StatementDescription {
|
||||
return c.invalidStmts
|
||||
}
|
||||
|
||||
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
|
||||
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
|
||||
// never seen by the call to GetInvalidated.
|
||||
func (c *UnlimitedCache) RemoveInvalidated() {
|
||||
c.invalidStmts = nil
|
||||
return invalidStmts
|
||||
}
|
||||
|
||||
// Len returns the number of cached prepared statement descriptions.
|
||||
|
||||
+55
-26
@@ -6,6 +6,11 @@ import (
|
||||
"io"
|
||||
)
|
||||
|
||||
// The PostgreSQL wire protocol has a limit of 1 GB - 1 per message. See definition of
|
||||
// PQ_LARGE_MESSAGE_LIMIT in the PostgreSQL source code. To allow for the other data
|
||||
// in the message,maxLargeObjectMessageLength should be no larger than 1 GB - 1 KB.
|
||||
var maxLargeObjectMessageLength = 1024*1024*1024 - 1024
|
||||
|
||||
// LargeObjects is a structure used to access the large objects API. It is only valid within the transaction where it
|
||||
// was created.
|
||||
//
|
||||
@@ -67,41 +72,65 @@ type LargeObject struct {
|
||||
}
|
||||
|
||||
// Write writes p to the large object and returns the number of bytes written and an error if not all of p was written.
|
||||
//
|
||||
// Write is implemented with a single call to lowrite. The PostgreSQL wire protocol has a limit of 1 GB - 1 per message.
|
||||
// See definition of PQ_LARGE_MESSAGE_LIMIT in the PostgreSQL source code. To allow for the other data in the message,
|
||||
// len(p) should be no larger than 1 GB - 1 KB.
|
||||
func (o *LargeObject) Write(p []byte) (int, error) {
|
||||
var n int
|
||||
err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p).Scan(&n)
|
||||
if err != nil {
|
||||
return n, err
|
||||
nTotal := 0
|
||||
for {
|
||||
expected := len(p) - nTotal
|
||||
if expected == 0 {
|
||||
break
|
||||
} else if expected > maxLargeObjectMessageLength {
|
||||
expected = maxLargeObjectMessageLength
|
||||
}
|
||||
|
||||
var n int
|
||||
err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p[nTotal:nTotal+expected]).Scan(&n)
|
||||
if err != nil {
|
||||
return nTotal, err
|
||||
}
|
||||
|
||||
if n < 0 {
|
||||
return nTotal, errors.New("failed to write to large object")
|
||||
}
|
||||
|
||||
nTotal += n
|
||||
|
||||
if n < expected {
|
||||
return nTotal, errors.New("short write to large object")
|
||||
} else if n > expected {
|
||||
return nTotal, errors.New("invalid write to large object")
|
||||
}
|
||||
}
|
||||
|
||||
if n < 0 {
|
||||
return 0, errors.New("failed to write to large object")
|
||||
}
|
||||
|
||||
return n, nil
|
||||
return nTotal, nil
|
||||
}
|
||||
|
||||
// Read reads up to len(p) bytes into p returning the number of bytes read.
|
||||
//
|
||||
// Read is implemented with a single call to loread. PostgreSQL internally allocates a single buffer for the response.
|
||||
// The largest buffer PostgreSQL will allocate is 1 GB - 1. See definition of MaxAllocSize in the PostgreSQL source
|
||||
// code. To allow for the other data in the message, len(p) should be no larger than 1 GB - 1 KB.
|
||||
func (o *LargeObject) Read(p []byte) (int, error) {
|
||||
var res []byte
|
||||
err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, len(p)).Scan(&res)
|
||||
copy(p, res)
|
||||
if err != nil {
|
||||
return len(res), err
|
||||
nTotal := 0
|
||||
for {
|
||||
expected := len(p) - nTotal
|
||||
if expected == 0 {
|
||||
break
|
||||
} else if expected > maxLargeObjectMessageLength {
|
||||
expected = maxLargeObjectMessageLength
|
||||
}
|
||||
|
||||
var res []byte
|
||||
err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, expected).Scan(&res)
|
||||
copy(p[nTotal:], res)
|
||||
nTotal += len(res)
|
||||
if err != nil {
|
||||
return nTotal, err
|
||||
}
|
||||
|
||||
if len(res) < expected {
|
||||
return nTotal, io.EOF
|
||||
} else if len(res) > expected {
|
||||
return nTotal, errors.New("invalid read of large object")
|
||||
}
|
||||
}
|
||||
|
||||
if len(res) < len(p) {
|
||||
err = io.EOF
|
||||
}
|
||||
return len(res), err
|
||||
return nTotal, nil
|
||||
}
|
||||
|
||||
// Seek moves the current location pointer to the new location specified by offset.
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// SetMaxLargeObjectMessageLength sets internal maxLargeObjectMessageLength variable
|
||||
// to the given length for the duration of the test.
|
||||
//
|
||||
// Tests using this helper should not use t.Parallel().
|
||||
func SetMaxLargeObjectMessageLength(t *testing.T, length int) {
|
||||
t.Helper()
|
||||
|
||||
original := maxLargeObjectMessageLength
|
||||
t.Cleanup(func() {
|
||||
maxLargeObjectMessageLength = original
|
||||
})
|
||||
|
||||
maxLargeObjectMessageLength = length
|
||||
}
|
||||
@@ -13,7 +13,8 @@ import (
|
||||
)
|
||||
|
||||
func TestLargeObjects(t *testing.T) {
|
||||
t.Parallel()
|
||||
// We use a very short limit to test chunking logic.
|
||||
pgx.SetMaxLargeObjectMessageLength(t, 2)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
@@ -34,7 +35,8 @@ func TestLargeObjects(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLargeObjectsSimpleProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
// We use a very short limit to test chunking logic.
|
||||
pgx.SetMaxLargeObjectMessageLength(t, 2)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
@@ -160,7 +162,8 @@ func testLargeObjects(t *testing.T, ctx context.Context, tx pgx.Tx) {
|
||||
}
|
||||
|
||||
func TestLargeObjectsMultipleTransactions(t *testing.T) {
|
||||
t.Parallel()
|
||||
// We use a very short limit to test chunking logic.
|
||||
pgx.SetMaxLargeObjectMessageLength(t, 2)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
+42
-16
@@ -2,6 +2,7 @@ package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
@@ -21,6 +22,34 @@ type NamedArgs map[string]any
|
||||
|
||||
// RewriteQuery implements the QueryRewriter interface.
|
||||
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
|
||||
return rewriteQuery(na, sql, false)
|
||||
}
|
||||
|
||||
// StrictNamedArgs can be used in the same way as NamedArgs, but provided arguments are also checked to include all
|
||||
// named arguments that the sql query uses, and no extra arguments.
|
||||
type StrictNamedArgs map[string]any
|
||||
|
||||
// RewriteQuery implements the QueryRewriter interface.
|
||||
func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
|
||||
return rewriteQuery(sna, sql, true)
|
||||
}
|
||||
|
||||
type namedArg string
|
||||
|
||||
type sqlLexer struct {
|
||||
src string
|
||||
start int
|
||||
pos int
|
||||
nested int // multiline comment nesting level.
|
||||
stateFn stateFn
|
||||
parts []any
|
||||
|
||||
nameToOrdinal map[namedArg]int
|
||||
}
|
||||
|
||||
type stateFn func(*sqlLexer) stateFn
|
||||
|
||||
func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string, newArgs []any, err error) {
|
||||
l := &sqlLexer{
|
||||
src: sql,
|
||||
stateFn: rawState,
|
||||
@@ -44,27 +73,24 @@ func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, ar
|
||||
|
||||
newArgs = make([]any, len(l.nameToOrdinal))
|
||||
for name, ordinal := range l.nameToOrdinal {
|
||||
newArgs[ordinal-1] = na[string(name)]
|
||||
var found bool
|
||||
newArgs[ordinal-1], found = na[string(name)]
|
||||
if isStrict && !found {
|
||||
return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", name)
|
||||
}
|
||||
}
|
||||
|
||||
if isStrict {
|
||||
for name := range na {
|
||||
if _, found := l.nameToOrdinal[namedArg(name)]; !found {
|
||||
return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String(), newArgs, nil
|
||||
}
|
||||
|
||||
type namedArg string
|
||||
|
||||
type sqlLexer struct {
|
||||
src string
|
||||
start int
|
||||
pos int
|
||||
nested int // multiline comment nesting level.
|
||||
stateFn stateFn
|
||||
parts []any
|
||||
|
||||
nameToOrdinal map[namedArg]int
|
||||
}
|
||||
|
||||
type stateFn func(*sqlLexer) stateFn
|
||||
|
||||
func rawState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
|
||||
@@ -93,6 +93,18 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
|
||||
where id = $1;`,
|
||||
expectedArgs: []any{int32(42)},
|
||||
},
|
||||
{
|
||||
sql: "extra provided argument",
|
||||
namedArgs: pgx.NamedArgs{"extra": int32(1)},
|
||||
expectedSQL: "extra provided argument",
|
||||
expectedArgs: []any{},
|
||||
},
|
||||
{
|
||||
sql: "@missing argument",
|
||||
namedArgs: pgx.NamedArgs{},
|
||||
expectedSQL: "$1 argument",
|
||||
expectedArgs: []any{nil},
|
||||
},
|
||||
|
||||
// test comments and quotes
|
||||
} {
|
||||
@@ -102,3 +114,49 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
|
||||
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStrictNamedArgsRewriteQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for i, tt := range []struct {
|
||||
sql string
|
||||
namedArgs pgx.StrictNamedArgs
|
||||
expectedSQL string
|
||||
expectedArgs []any
|
||||
isExpectedError bool
|
||||
}{
|
||||
{
|
||||
sql: "no arguments",
|
||||
namedArgs: pgx.StrictNamedArgs{},
|
||||
expectedSQL: "no arguments",
|
||||
expectedArgs: []any{},
|
||||
isExpectedError: false,
|
||||
},
|
||||
{
|
||||
sql: "@all @matches",
|
||||
namedArgs: pgx.StrictNamedArgs{"all": int32(1), "matches": int32(2)},
|
||||
expectedSQL: "$1 $2",
|
||||
expectedArgs: []any{int32(1), int32(2)},
|
||||
isExpectedError: false,
|
||||
},
|
||||
{
|
||||
sql: "extra provided argument",
|
||||
namedArgs: pgx.StrictNamedArgs{"extra": int32(1)},
|
||||
isExpectedError: true,
|
||||
},
|
||||
{
|
||||
sql: "@missing argument",
|
||||
namedArgs: pgx.StrictNamedArgs{},
|
||||
isExpectedError: true,
|
||||
},
|
||||
} {
|
||||
sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, nil)
|
||||
if tt.isExpectedError {
|
||||
assert.Errorf(t, err, "%d", i)
|
||||
} else {
|
||||
require.NoErrorf(t, err, "%d", i)
|
||||
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
|
||||
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -721,6 +721,9 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
|
||||
return nil, fmt.Errorf("unable to read sslkey: %w", err)
|
||||
}
|
||||
block, _ := pem.Decode(buf)
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode sslkey")
|
||||
}
|
||||
var pemKey []byte
|
||||
var decryptedKey []byte
|
||||
var decryptedError error
|
||||
|
||||
+43
-5
@@ -1674,25 +1674,55 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
|
||||
// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip.
|
||||
type Batch struct {
|
||||
buf []byte
|
||||
err error
|
||||
}
|
||||
|
||||
// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions.
|
||||
func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
|
||||
batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf)
|
||||
if batch.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
batch.buf, batch.err = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf)
|
||||
if batch.err != nil {
|
||||
return
|
||||
}
|
||||
batch.ExecPrepared("", paramValues, paramFormats, resultFormats)
|
||||
}
|
||||
|
||||
// ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions.
|
||||
func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
|
||||
batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf)
|
||||
batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf)
|
||||
batch.buf = (&pgproto3.Execute{}).Encode(batch.buf)
|
||||
if batch.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
batch.buf, batch.err = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf)
|
||||
if batch.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
batch.buf, batch.err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf)
|
||||
if batch.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
batch.buf, batch.err = (&pgproto3.Execute{}).Encode(batch.buf)
|
||||
if batch.err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a
|
||||
// transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing
|
||||
// multiple queries in a single round trip than using pipeline mode.
|
||||
func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader {
|
||||
if batch.err != nil {
|
||||
return &MultiResultReader{
|
||||
closed: true,
|
||||
err: batch.err,
|
||||
}
|
||||
}
|
||||
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return &MultiResultReader{
|
||||
closed: true,
|
||||
@@ -1718,7 +1748,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
}
|
||||
|
||||
batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)
|
||||
batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
|
||||
if batch.err != nil {
|
||||
multiResult.closed = true
|
||||
multiResult.err = batch.err
|
||||
pgConn.unlock()
|
||||
return multiResult
|
||||
}
|
||||
|
||||
pgConn.enterPotentialWriteReadDeadlock()
|
||||
defer pgConn.exitPotentialWriteReadDeadlock()
|
||||
@@ -2094,6 +2130,8 @@ func (p *Pipeline) getResults() (results any, err error) {
|
||||
for {
|
||||
msg, err := p.conn.receiveMessage()
|
||||
if err != nil {
|
||||
p.closed = true
|
||||
p.err = err
|
||||
p.conn.asyncClose()
|
||||
return nil, normalizeTimeoutError(p.ctx, err)
|
||||
}
|
||||
|
||||
+93
-3
@@ -3363,9 +3363,9 @@ func TestSNISupport(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
srv.Write((&pgproto3.AuthenticationOk{}).Encode(nil))
|
||||
srv.Write((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil))
|
||||
srv.Write((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil))
|
||||
srv.Write(mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil)))
|
||||
srv.Write(mustEncode((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil)))
|
||||
srv.Write(mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)))
|
||||
|
||||
serverSNINameChan <- sniHost
|
||||
}()
|
||||
@@ -3389,3 +3389,93 @@ func TestSNISupport(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/1920
|
||||
func TestFatalErrorReceivedInPipelineMode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
steps := pgmock.AcceptUnauthenticatedConnRequestSteps()
|
||||
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{}))
|
||||
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{}))
|
||||
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{}))
|
||||
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{}))
|
||||
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{}))
|
||||
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{}))
|
||||
steps = append(steps, pgmock.SendMessage(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
|
||||
{Name: []byte("mock")},
|
||||
}}))
|
||||
steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"}))
|
||||
// We shouldn't get anything after the first fatal error. But the reported issue was with PgBouncer so maybe that
|
||||
// causes the issue. Anyway, a FATAL error after the connection had already been killed could cause a panic.
|
||||
steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"}))
|
||||
|
||||
script := &pgmock.Script{Steps: steps}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
serverKeepAlive := make(chan struct{})
|
||||
defer close(serverKeepAlive)
|
||||
|
||||
serverErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(serverErrChan)
|
||||
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.SetDeadline(time.Now().Add(59 * time.Second))
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
err = script.Run(pgproto3.NewBackend(conn, conn))
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
<-serverKeepAlive
|
||||
}()
|
||||
|
||||
parts := strings.Split(ln.Addr().String(), ":")
|
||||
host := parts[0]
|
||||
port := parts[1]
|
||||
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
|
||||
|
||||
ctx, cancel = context.WithTimeout(ctx, 59*time.Second)
|
||||
defer cancel()
|
||||
conn, err := pgconn.Connect(ctx, connStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
pipeline := conn.StartPipeline(ctx)
|
||||
pipeline.SendPrepare("s1", "select 1", nil)
|
||||
pipeline.SendPrepare("s2", "select 2", nil)
|
||||
pipeline.SendPrepare("s3", "select 3", nil)
|
||||
err = pipeline.Sync()
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
_, err = pipeline.GetResults()
|
||||
require.Error(t, err)
|
||||
|
||||
err = pipeline.Close()
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func mustEncode(buf []byte, err error) []byte {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
@@ -35,11 +35,10 @@ func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, 8)
|
||||
func (src *AuthenticationCleartextPassword) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -27,11 +27,10 @@ func (a *AuthenticationGSS) Decode(src []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSS) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, 4)
|
||||
func (a *AuthenticationGSS) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeGSS)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {
|
||||
|
||||
@@ -31,12 +31,11 @@ func (a *AuthenticationGSSContinue) Decode(src []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
|
||||
func (a *AuthenticationGSSContinue) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
|
||||
dst = append(dst, a.Data...)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {
|
||||
|
||||
@@ -38,12 +38,11 @@ func (dst *AuthenticationMD5Password) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, 12)
|
||||
func (src *AuthenticationMD5Password) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
|
||||
dst = append(dst, src.Salt[:]...)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -35,11 +35,10 @@ func (dst *AuthenticationOk) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationOk) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, 8)
|
||||
func (src *AuthenticationOk) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeOk)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -47,10 +47,8 @@ func (dst *AuthenticationSASL) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationSASL) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeSASL)
|
||||
|
||||
for _, s := range src.AuthMechanisms {
|
||||
@@ -59,9 +57,7 @@ func (src *AuthenticationSASL) Encode(dst []byte) []byte {
|
||||
}
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -38,17 +38,11 @@ func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *AuthenticationSASLContinue) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
|
||||
|
||||
dst = append(dst, src.Data...)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -38,17 +38,11 @@ func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *AuthenticationSASLFinal) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'R')
|
||||
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
|
||||
|
||||
dst = append(dst, src.Data...)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Unmarshaler.
|
||||
|
||||
+21
-4
@@ -16,7 +16,8 @@ type Backend struct {
|
||||
// before it is actually transmitted (i.e. before Flush).
|
||||
tracer *tracer
|
||||
|
||||
wbuf []byte
|
||||
wbuf []byte
|
||||
encodeError error
|
||||
|
||||
// Frontend message flyweights
|
||||
bind Bind
|
||||
@@ -55,11 +56,21 @@ func NewBackend(r io.Reader, w io.Writer) *Backend {
|
||||
return &Backend{cr: cr, w: w}
|
||||
}
|
||||
|
||||
// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is
|
||||
// called.
|
||||
// Send sends a message to the frontend (i.e. the client). The message is buffered until Flush is called. Any error
|
||||
// encountered will be returned from Flush.
|
||||
func (b *Backend) Send(msg BackendMessage) {
|
||||
if b.encodeError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
prevLen := len(b.wbuf)
|
||||
b.wbuf = msg.Encode(b.wbuf)
|
||||
newBuf, err := msg.Encode(b.wbuf)
|
||||
if err != nil {
|
||||
b.encodeError = err
|
||||
return
|
||||
}
|
||||
b.wbuf = newBuf
|
||||
|
||||
if b.tracer != nil {
|
||||
b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg)
|
||||
}
|
||||
@@ -67,6 +78,12 @@ func (b *Backend) Send(msg BackendMessage) {
|
||||
|
||||
// Flush writes any pending messages to the frontend (i.e. the client).
|
||||
func (b *Backend) Flush() error {
|
||||
if err := b.encodeError; err != nil {
|
||||
b.encodeError = nil
|
||||
b.wbuf = b.wbuf[:0]
|
||||
return &writeError{err: err, safeToRetry: true}
|
||||
}
|
||||
|
||||
n, err := b.w.Write(b.wbuf)
|
||||
|
||||
const maxLen = 1024
|
||||
|
||||
@@ -29,12 +29,11 @@ func (dst *BackendKeyData) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *BackendKeyData) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'K')
|
||||
dst = pgio.AppendUint32(dst, 12)
|
||||
func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'K')
|
||||
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -71,8 +71,8 @@ func TestStartupMessage(t *testing.T) {
|
||||
"username": "tester",
|
||||
},
|
||||
}
|
||||
dst := []byte{}
|
||||
dst = want.Encode(dst)
|
||||
dst, err := want.Encode([]byte{})
|
||||
require.NoError(t, err)
|
||||
|
||||
server := &interruptReader{}
|
||||
server.push(dst)
|
||||
|
||||
+14
-7
@@ -5,7 +5,9 @@ import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||
)
|
||||
@@ -108,21 +110,25 @@ func (dst *Bind) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Bind) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'B')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *Bind) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'B')
|
||||
|
||||
dst = append(dst, src.DestinationPortal...)
|
||||
dst = append(dst, 0)
|
||||
dst = append(dst, src.PreparedStatement...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
if len(src.ParameterFormatCodes) > math.MaxUint16 {
|
||||
return nil, errors.New("too many parameter format codes")
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
|
||||
for _, fc := range src.ParameterFormatCodes {
|
||||
dst = pgio.AppendInt16(dst, fc)
|
||||
}
|
||||
|
||||
if len(src.Parameters) > math.MaxUint16 {
|
||||
return nil, errors.New("too many parameters")
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
|
||||
for _, p := range src.Parameters {
|
||||
if p == nil {
|
||||
@@ -134,14 +140,15 @@ func (src *Bind) Encode(dst []byte) []byte {
|
||||
dst = append(dst, p...)
|
||||
}
|
||||
|
||||
if len(src.ResultFormatCodes) > math.MaxUint16 {
|
||||
return nil, errors.New("too many result format codes")
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
|
||||
for _, fc := range src.ResultFormatCodes {
|
||||
dst = pgio.AppendInt16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -20,8 +20,8 @@ func (dst *BindComplete) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *BindComplete) Encode(dst []byte) []byte {
|
||||
return append(dst, '2', 0, 0, 0, 4)
|
||||
func (src *BindComplete) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, '2', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
package pgproto3_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/pgproto3"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBindBiggerThanMaxMessageBodyLen(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Maximum allowed size.
|
||||
_, err := (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-16)}}).Encode(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 1 byte too big
|
||||
_, err = (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-15)}}).Encode(nil)
|
||||
require.Error(t, err)
|
||||
}
|
||||
@@ -36,12 +36,12 @@ func (dst *CancelRequest) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||
func (src *CancelRequest) Encode(dst []byte) []byte {
|
||||
func (src *CancelRequest) Encode(dst []byte) ([]byte, error) {
|
||||
dst = pgio.AppendInt32(dst, 16)
|
||||
dst = pgio.AppendInt32(dst, cancelRequestCode)
|
||||
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||
return dst
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
+3
-11
@@ -4,8 +4,6 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type Close struct {
|
||||
@@ -37,18 +35,12 @@ func (dst *Close) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Close) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'C')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
func (src *Close) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'C')
|
||||
dst = append(dst, src.ObjectType)
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -20,8 +20,8 @@ func (dst *CloseComplete) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CloseComplete) Encode(dst []byte) []byte {
|
||||
return append(dst, '3', 0, 0, 0, 4)
|
||||
func (src *CloseComplete) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, '3', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -3,8 +3,6 @@ package pgproto3
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type CommandComplete struct {
|
||||
@@ -31,17 +29,11 @@ func (dst *CommandComplete) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CommandComplete) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'C')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
func (src *CommandComplete) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'C')
|
||||
dst = append(dst, src.CommandTag...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||
)
|
||||
@@ -44,19 +45,18 @@ func (dst *CopyBothResponse) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyBothResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'W')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *CopyBothResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'W')
|
||||
dst = append(dst, src.OverallFormat)
|
||||
if len(src.ColumnFormatCodes) > math.MaxUint16 {
|
||||
return nil, errors.New("too many column format codes")
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||
for _, fc := range src.ColumnFormatCodes {
|
||||
dst = pgio.AppendUint16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/andoma-go/pgx/v5/pgproto3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEncodeDecode(t *testing.T) {
|
||||
@@ -13,6 +14,7 @@ func TestEncodeDecode(t *testing.T) {
|
||||
err := dstResp.Decode(srcBytes[5:])
|
||||
assert.NoError(t, err, "No errors on decode")
|
||||
dstBytes := []byte{}
|
||||
dstBytes = dstResp.Encode(dstBytes)
|
||||
dstBytes, err = dstResp.Encode(dstBytes)
|
||||
require.NoError(t, err)
|
||||
assert.EqualValues(t, srcBytes, dstBytes, "Expecting src & dest bytes to match")
|
||||
}
|
||||
|
||||
@@ -3,8 +3,6 @@ package pgproto3
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type CopyData struct {
|
||||
@@ -25,11 +23,10 @@ func (dst *CopyData) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyData) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'd')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
|
||||
func (src *CopyData) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'd')
|
||||
dst = append(dst, src.Data...)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -24,8 +24,8 @@ func (dst *CopyDone) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyDone) Encode(dst []byte) []byte {
|
||||
return append(dst, 'c', 0, 0, 0, 4)
|
||||
func (src *CopyDone) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, 'c', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
+3
-11
@@ -3,8 +3,6 @@ package pgproto3
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type CopyFail struct {
|
||||
@@ -28,17 +26,11 @@ func (dst *CopyFail) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyFail) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'f')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
func (src *CopyFail) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'f')
|
||||
dst = append(dst, src.Message...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||
)
|
||||
@@ -44,20 +45,19 @@ func (dst *CopyInResponse) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyInResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'G')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *CopyInResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'G')
|
||||
|
||||
dst = append(dst, src.OverallFormat)
|
||||
if len(src.ColumnFormatCodes) > math.MaxUint16 {
|
||||
return nil, errors.New("too many column format codes")
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||
for _, fc := range src.ColumnFormatCodes {
|
||||
dst = pgio.AppendUint16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||
)
|
||||
@@ -43,21 +44,20 @@ func (dst *CopyOutResponse) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyOutResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'H')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *CopyOutResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'H')
|
||||
|
||||
dst = append(dst, src.OverallFormat)
|
||||
|
||||
if len(src.ColumnFormatCodes) > math.MaxUint16 {
|
||||
return nil, errors.New("too many column format codes")
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||
for _, fc := range src.ColumnFormatCodes {
|
||||
dst = pgio.AppendUint16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||
)
|
||||
@@ -63,11 +65,12 @@ func (dst *DataRow) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *DataRow) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'D')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *DataRow) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'D')
|
||||
|
||||
if len(src.Values) > math.MaxUint16 {
|
||||
return nil, errors.New("too many values")
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
|
||||
for _, v := range src.Values {
|
||||
if v == nil {
|
||||
@@ -79,9 +82,7 @@ func (src *DataRow) Encode(dst []byte) []byte {
|
||||
dst = append(dst, v...)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
+3
-11
@@ -4,8 +4,6 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type Describe struct {
|
||||
@@ -37,18 +35,12 @@ func (dst *Describe) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Describe) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'D')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
func (src *Describe) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'D')
|
||||
dst = append(dst, src.ObjectType)
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -20,8 +20,8 @@ func (dst *EmptyQueryResponse) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *EmptyQueryResponse) Encode(dst []byte) []byte {
|
||||
return append(dst, 'I', 0, 0, 0, 4)
|
||||
func (src *EmptyQueryResponse) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, 'I', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
+64
-71
@@ -2,7 +2,6 @@ package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
)
|
||||
@@ -111,119 +110,113 @@ func (dst *ErrorResponse) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *ErrorResponse) Encode(dst []byte) []byte {
|
||||
return append(dst, src.marshalBinary('E')...)
|
||||
func (src *ErrorResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'E')
|
||||
dst = src.appendFields(dst)
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
|
||||
var bigEndian BigEndianBuf
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
buf.WriteByte(typeByte)
|
||||
buf.Write(bigEndian.Uint32(0))
|
||||
|
||||
func (src *ErrorResponse) appendFields(dst []byte) []byte {
|
||||
if src.Severity != "" {
|
||||
buf.WriteByte('S')
|
||||
buf.WriteString(src.Severity)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'S')
|
||||
dst = append(dst, src.Severity...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.SeverityUnlocalized != "" {
|
||||
buf.WriteByte('V')
|
||||
buf.WriteString(src.SeverityUnlocalized)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'V')
|
||||
dst = append(dst, src.SeverityUnlocalized...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Code != "" {
|
||||
buf.WriteByte('C')
|
||||
buf.WriteString(src.Code)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'C')
|
||||
dst = append(dst, src.Code...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Message != "" {
|
||||
buf.WriteByte('M')
|
||||
buf.WriteString(src.Message)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'M')
|
||||
dst = append(dst, src.Message...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Detail != "" {
|
||||
buf.WriteByte('D')
|
||||
buf.WriteString(src.Detail)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'D')
|
||||
dst = append(dst, src.Detail...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Hint != "" {
|
||||
buf.WriteByte('H')
|
||||
buf.WriteString(src.Hint)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'H')
|
||||
dst = append(dst, src.Hint...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Position != 0 {
|
||||
buf.WriteByte('P')
|
||||
buf.WriteString(strconv.Itoa(int(src.Position)))
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'P')
|
||||
dst = append(dst, strconv.Itoa(int(src.Position))...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.InternalPosition != 0 {
|
||||
buf.WriteByte('p')
|
||||
buf.WriteString(strconv.Itoa(int(src.InternalPosition)))
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'p')
|
||||
dst = append(dst, strconv.Itoa(int(src.InternalPosition))...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.InternalQuery != "" {
|
||||
buf.WriteByte('q')
|
||||
buf.WriteString(src.InternalQuery)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'q')
|
||||
dst = append(dst, src.InternalQuery...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Where != "" {
|
||||
buf.WriteByte('W')
|
||||
buf.WriteString(src.Where)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'W')
|
||||
dst = append(dst, src.Where...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.SchemaName != "" {
|
||||
buf.WriteByte('s')
|
||||
buf.WriteString(src.SchemaName)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 's')
|
||||
dst = append(dst, src.SchemaName...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.TableName != "" {
|
||||
buf.WriteByte('t')
|
||||
buf.WriteString(src.TableName)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 't')
|
||||
dst = append(dst, src.TableName...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.ColumnName != "" {
|
||||
buf.WriteByte('c')
|
||||
buf.WriteString(src.ColumnName)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'c')
|
||||
dst = append(dst, src.ColumnName...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.DataTypeName != "" {
|
||||
buf.WriteByte('d')
|
||||
buf.WriteString(src.DataTypeName)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'd')
|
||||
dst = append(dst, src.DataTypeName...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.ConstraintName != "" {
|
||||
buf.WriteByte('n')
|
||||
buf.WriteString(src.ConstraintName)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'n')
|
||||
dst = append(dst, src.ConstraintName...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.File != "" {
|
||||
buf.WriteByte('F')
|
||||
buf.WriteString(src.File)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'F')
|
||||
dst = append(dst, src.File...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Line != 0 {
|
||||
buf.WriteByte('L')
|
||||
buf.WriteString(strconv.Itoa(int(src.Line)))
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'L')
|
||||
dst = append(dst, strconv.Itoa(int(src.Line))...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
if src.Routine != "" {
|
||||
buf.WriteByte('R')
|
||||
buf.WriteString(src.Routine)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 'R')
|
||||
dst = append(dst, src.Routine...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
|
||||
for k, v := range src.UnknownFields {
|
||||
buf.WriteByte(k)
|
||||
buf.WriteString(v)
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, k)
|
||||
dst = append(dst, v...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
|
||||
buf.WriteByte(0)
|
||||
dst = append(dst, 0)
|
||||
|
||||
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
|
||||
|
||||
return buf.Bytes()
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -46,7 +46,7 @@ func (p *PgFortuneBackend) Run() error {
|
||||
return fmt.Errorf("error generating query response: %w", err)
|
||||
}
|
||||
|
||||
buf := (&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
|
||||
buf := mustEncode((&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
|
||||
{
|
||||
Name: []byte("fortune"),
|
||||
TableOID: 0,
|
||||
@@ -56,10 +56,10 @@ func (p *PgFortuneBackend) Run() error {
|
||||
TypeModifier: -1,
|
||||
Format: 0,
|
||||
},
|
||||
}}).Encode(nil)
|
||||
buf = (&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf)
|
||||
buf = (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf)
|
||||
buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
|
||||
}}).Encode(nil))
|
||||
buf = mustEncode((&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf))
|
||||
buf = mustEncode((&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf))
|
||||
buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf))
|
||||
_, err = p.conn.Write(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing query response: %w", err)
|
||||
@@ -80,8 +80,8 @@ func (p *PgFortuneBackend) handleStartup() error {
|
||||
|
||||
switch startupMessage.(type) {
|
||||
case *pgproto3.StartupMessage:
|
||||
buf := (&pgproto3.AuthenticationOk{}).Encode(nil)
|
||||
buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
|
||||
buf := mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil))
|
||||
buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf))
|
||||
_, err = p.conn.Write(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending ready for query: %w", err)
|
||||
@@ -102,3 +102,10 @@ func (p *PgFortuneBackend) handleStartup() error {
|
||||
func (p *PgFortuneBackend) Close() error {
|
||||
return p.conn.Close()
|
||||
}
|
||||
|
||||
func mustEncode(buf []byte, err error) []byte {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
+3
-10
@@ -36,19 +36,12 @@ func (dst *Execute) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Execute) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'E')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
func (src *Execute) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'E')
|
||||
dst = append(dst, src.Portal...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
dst = pgio.AppendUint32(dst, src.MaxRows)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
+2
-2
@@ -20,8 +20,8 @@ func (dst *Flush) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Flush) Encode(dst []byte) []byte {
|
||||
return append(dst, 'H', 0, 0, 0, 4)
|
||||
func (src *Flush) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, 'H', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
+112
-25
@@ -18,7 +18,8 @@ type Frontend struct {
|
||||
// idle. Setting and unsetting tracer provides equivalent functionality to PQtrace and PQuntrace in libpq.
|
||||
tracer *tracer
|
||||
|
||||
wbuf []byte
|
||||
wbuf []byte
|
||||
encodeError error
|
||||
|
||||
// Backend message flyweights
|
||||
authenticationOk AuthenticationOk
|
||||
@@ -64,16 +65,26 @@ func NewFrontend(r io.Reader, w io.Writer) *Frontend {
|
||||
return &Frontend{cr: cr, w: w}
|
||||
}
|
||||
|
||||
// Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is
|
||||
// called.
|
||||
// Send sends a message to the backend (i.e. the server). The message is buffered until Flush is called. Any error
|
||||
// encountered will be returned from Flush.
|
||||
//
|
||||
// Send can work with any FrontendMessage. Some commonly used message types such as Bind have specialized send methods
|
||||
// such as SendBind. These methods should be preferred when the type of message is known up front (e.g. when building an
|
||||
// extended query protocol query) as they may be faster due to knowing the type of msg rather than it being hidden
|
||||
// behind an interface.
|
||||
func (f *Frontend) Send(msg FrontendMessage) {
|
||||
if f.encodeError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
prevLen := len(f.wbuf)
|
||||
f.wbuf = msg.Encode(f.wbuf)
|
||||
newBuf, err := msg.Encode(f.wbuf)
|
||||
if err != nil {
|
||||
f.encodeError = err
|
||||
return
|
||||
}
|
||||
f.wbuf = newBuf
|
||||
|
||||
if f.tracer != nil {
|
||||
f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg)
|
||||
}
|
||||
@@ -81,6 +92,12 @@ func (f *Frontend) Send(msg FrontendMessage) {
|
||||
|
||||
// Flush writes any pending messages to the backend (i.e. the server).
|
||||
func (f *Frontend) Flush() error {
|
||||
if err := f.encodeError; err != nil {
|
||||
f.encodeError = nil
|
||||
f.wbuf = f.wbuf[:0]
|
||||
return &writeError{err: err, safeToRetry: true}
|
||||
}
|
||||
|
||||
if len(f.wbuf) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -116,71 +133,141 @@ func (f *Frontend) Untrace() {
|
||||
f.tracer = nil
|
||||
}
|
||||
|
||||
// SendBind sends a Bind message to the backend (i.e. the server). The message is not guaranteed to be written until
|
||||
// Flush is called.
|
||||
// SendBind sends a Bind message to the backend (i.e. the server). The message is buffered until Flush is called. Any
|
||||
// error encountered will be returned from Flush.
|
||||
func (f *Frontend) SendBind(msg *Bind) {
|
||||
if f.encodeError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
prevLen := len(f.wbuf)
|
||||
f.wbuf = msg.Encode(f.wbuf)
|
||||
newBuf, err := msg.Encode(f.wbuf)
|
||||
if err != nil {
|
||||
f.encodeError = err
|
||||
return
|
||||
}
|
||||
f.wbuf = newBuf
|
||||
|
||||
if f.tracer != nil {
|
||||
f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg)
|
||||
}
|
||||
}
|
||||
|
||||
// SendParse sends a Parse message to the backend (i.e. the server). The message is not guaranteed to be written until
|
||||
// Flush is called.
|
||||
// SendParse sends a Parse message to the backend (i.e. the server). The message is buffered until Flush is called. Any
|
||||
// error encountered will be returned from Flush.
|
||||
func (f *Frontend) SendParse(msg *Parse) {
|
||||
if f.encodeError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
prevLen := len(f.wbuf)
|
||||
f.wbuf = msg.Encode(f.wbuf)
|
||||
newBuf, err := msg.Encode(f.wbuf)
|
||||
if err != nil {
|
||||
f.encodeError = err
|
||||
return
|
||||
}
|
||||
f.wbuf = newBuf
|
||||
|
||||
if f.tracer != nil {
|
||||
f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg)
|
||||
}
|
||||
}
|
||||
|
||||
// SendClose sends a Close message to the backend (i.e. the server). The message is not guaranteed to be written until
|
||||
// Flush is called.
|
||||
// SendClose sends a Close message to the backend (i.e. the server). The message is buffered until Flush is called. Any
|
||||
// error encountered will be returned from Flush.
|
||||
func (f *Frontend) SendClose(msg *Close) {
|
||||
if f.encodeError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
prevLen := len(f.wbuf)
|
||||
f.wbuf = msg.Encode(f.wbuf)
|
||||
newBuf, err := msg.Encode(f.wbuf)
|
||||
if err != nil {
|
||||
f.encodeError = err
|
||||
return
|
||||
}
|
||||
f.wbuf = newBuf
|
||||
|
||||
if f.tracer != nil {
|
||||
f.tracer.traceClose('F', int32(len(f.wbuf)-prevLen), msg)
|
||||
}
|
||||
}
|
||||
|
||||
// SendDescribe sends a Describe message to the backend (i.e. the server). The message is not guaranteed to be written until
|
||||
// Flush is called.
|
||||
// SendDescribe sends a Describe message to the backend (i.e. the server). The message is buffered until Flush is
|
||||
// called. Any error encountered will be returned from Flush.
|
||||
func (f *Frontend) SendDescribe(msg *Describe) {
|
||||
if f.encodeError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
prevLen := len(f.wbuf)
|
||||
f.wbuf = msg.Encode(f.wbuf)
|
||||
newBuf, err := msg.Encode(f.wbuf)
|
||||
if err != nil {
|
||||
f.encodeError = err
|
||||
return
|
||||
}
|
||||
f.wbuf = newBuf
|
||||
|
||||
if f.tracer != nil {
|
||||
f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg)
|
||||
}
|
||||
}
|
||||
|
||||
// SendExecute sends an Execute message to the backend (i.e. the server). The message is not guaranteed to be written until
|
||||
// Flush is called.
|
||||
// SendExecute sends an Execute message to the backend (i.e. the server). The message is buffered until Flush is called.
|
||||
// Any error encountered will be returned from Flush.
|
||||
func (f *Frontend) SendExecute(msg *Execute) {
|
||||
if f.encodeError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
prevLen := len(f.wbuf)
|
||||
f.wbuf = msg.Encode(f.wbuf)
|
||||
newBuf, err := msg.Encode(f.wbuf)
|
||||
if err != nil {
|
||||
f.encodeError = err
|
||||
return
|
||||
}
|
||||
f.wbuf = newBuf
|
||||
|
||||
if f.tracer != nil {
|
||||
f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg)
|
||||
}
|
||||
}
|
||||
|
||||
// SendSync sends a Sync message to the backend (i.e. the server). The message is not guaranteed to be written until
|
||||
// Flush is called.
|
||||
// SendSync sends a Sync message to the backend (i.e. the server). The message is buffered until Flush is called. Any
|
||||
// error encountered will be returned from Flush.
|
||||
func (f *Frontend) SendSync(msg *Sync) {
|
||||
if f.encodeError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
prevLen := len(f.wbuf)
|
||||
f.wbuf = msg.Encode(f.wbuf)
|
||||
newBuf, err := msg.Encode(f.wbuf)
|
||||
if err != nil {
|
||||
f.encodeError = err
|
||||
return
|
||||
}
|
||||
f.wbuf = newBuf
|
||||
|
||||
if f.tracer != nil {
|
||||
f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg)
|
||||
}
|
||||
}
|
||||
|
||||
// SendQuery sends a Query message to the backend (i.e. the server). The message is not guaranteed to be written until
|
||||
// Flush is called.
|
||||
// SendQuery sends a Query message to the backend (i.e. the server). The message is buffered until Flush is called. Any
|
||||
// error encountered will be returned from Flush.
|
||||
func (f *Frontend) SendQuery(msg *Query) {
|
||||
if f.encodeError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
prevLen := len(f.wbuf)
|
||||
f.wbuf = msg.Encode(f.wbuf)
|
||||
newBuf, err := msg.Encode(f.wbuf)
|
||||
if err != nil {
|
||||
f.encodeError = err
|
||||
return
|
||||
}
|
||||
f.wbuf = newBuf
|
||||
|
||||
if f.tracer != nil {
|
||||
f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||
)
|
||||
@@ -71,15 +73,21 @@ func (dst *FunctionCall) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *FunctionCall) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'F')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendUint32(dst, 0) // Unknown length, set it at the end
|
||||
func (src *FunctionCall) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'F')
|
||||
dst = pgio.AppendUint32(dst, src.Function)
|
||||
|
||||
if len(src.ArgFormatCodes) > math.MaxUint16 {
|
||||
return nil, errors.New("too many arg format codes")
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes)))
|
||||
for _, argFormatCode := range src.ArgFormatCodes {
|
||||
dst = pgio.AppendUint16(dst, argFormatCode)
|
||||
}
|
||||
|
||||
if len(src.Arguments) > math.MaxUint16 {
|
||||
return nil, errors.New("too many arguments")
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.Arguments)))
|
||||
for _, argument := range src.Arguments {
|
||||
if argument == nil {
|
||||
@@ -90,6 +98,5 @@ func (src *FunctionCall) Encode(dst []byte) []byte {
|
||||
}
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, src.ResultFormatCode)
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
@@ -39,10 +39,8 @@ func (dst *FunctionCallResponse) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *FunctionCallResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'V')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *FunctionCallResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'V')
|
||||
|
||||
if src.Result == nil {
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
@@ -51,9 +49,7 @@ func (src *FunctionCallResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, src.Result...)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"encoding/binary"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFunctionCall_EncodeDecode(t *testing.T) {
|
||||
@@ -30,7 +32,8 @@ func TestFunctionCall_EncodeDecode(t *testing.T) {
|
||||
Arguments: tt.fields.Arguments,
|
||||
ResultFormatCode: tt.fields.ResultFormatCode,
|
||||
}
|
||||
encoded := src.Encode([]byte{})
|
||||
encoded, err := src.Encode([]byte{})
|
||||
require.NoError(t, err)
|
||||
dst := &FunctionCall{}
|
||||
// Check the header
|
||||
msgTypeCode := encoded[0]
|
||||
@@ -44,7 +47,7 @@ func TestFunctionCall_EncodeDecode(t *testing.T) {
|
||||
t.Errorf("Incorrect message length, got = %v, wanted = %v", l, len(encoded))
|
||||
}
|
||||
// Check decoding works as expected
|
||||
err := dst.Decode(encoded[5:])
|
||||
err = dst.Decode(encoded[5:])
|
||||
if err != nil {
|
||||
if !tt.wantErr {
|
||||
t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
@@ -31,10 +31,10 @@ func (dst *GSSEncRequest) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||
func (src *GSSEncRequest) Encode(dst []byte) []byte {
|
||||
func (src *GSSEncRequest) Encode(dst []byte) ([]byte, error) {
|
||||
dst = pgio.AppendInt32(dst, 8)
|
||||
dst = pgio.AppendInt32(dst, gssEncReqNumber)
|
||||
return dst
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -2,8 +2,6 @@ package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type GSSResponse struct {
|
||||
@@ -18,11 +16,10 @@ func (g *GSSResponse) Decode(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *GSSResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'p')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(g.Data)))
|
||||
func (g *GSSResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'p')
|
||||
dst = append(dst, g.Data...)
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
+2
-2
@@ -20,8 +20,8 @@ func (dst *NoData) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *NoData) Encode(dst []byte) []byte {
|
||||
return append(dst, 'n', 0, 0, 0, 4)
|
||||
func (src *NoData) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, 'n', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -12,6 +12,8 @@ func (dst *NoticeResponse) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *NoticeResponse) Encode(dst []byte) []byte {
|
||||
return append(dst, (*ErrorResponse)(src).marshalBinary('N')...)
|
||||
func (src *NoticeResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'N')
|
||||
dst = (*ErrorResponse)(src).appendFields(dst)
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
@@ -45,20 +45,14 @@ func (dst *NotificationResponse) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *NotificationResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'A')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
func (src *NotificationResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'A')
|
||||
dst = pgio.AppendUint32(dst, src.PID)
|
||||
dst = append(dst, src.Channel...)
|
||||
dst = append(dst, 0)
|
||||
dst = append(dst, src.Payload...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||
)
|
||||
@@ -39,19 +41,18 @@ func (dst *ParameterDescription) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *ParameterDescription) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 't')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *ParameterDescription) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 't')
|
||||
|
||||
if len(src.ParameterOIDs) > math.MaxUint16 {
|
||||
return nil, errors.New("too many parameter oids")
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
|
||||
for _, oid := range src.ParameterOIDs {
|
||||
dst = pgio.AppendUint32(dst, oid)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -3,8 +3,6 @@ package pgproto3
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type ParameterStatus struct {
|
||||
@@ -37,19 +35,13 @@ func (dst *ParameterStatus) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *ParameterStatus) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'S')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
func (src *ParameterStatus) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'S')
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
dst = append(dst, src.Value...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
+8
-7
@@ -4,6 +4,8 @@ import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||
)
|
||||
@@ -52,24 +54,23 @@ func (dst *Parse) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Parse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'P')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *Parse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'P')
|
||||
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
dst = append(dst, src.Query...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
if len(src.ParameterOIDs) > math.MaxUint16 {
|
||||
return nil, errors.New("too many parameter oids")
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
|
||||
for _, oid := range src.ParameterOIDs {
|
||||
dst = pgio.AppendUint32(dst, oid)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -20,8 +20,8 @@ func (dst *ParseComplete) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *ParseComplete) Encode(dst []byte) []byte {
|
||||
return append(dst, '1', 0, 0, 0, 4)
|
||||
func (src *ParseComplete) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, '1', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -3,8 +3,6 @@ package pgproto3
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type PasswordMessage struct {
|
||||
@@ -32,14 +30,11 @@ func (dst *PasswordMessage) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *PasswordMessage) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'p')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1))
|
||||
|
||||
func (src *PasswordMessage) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'p')
|
||||
dst = append(dst, src.Password...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
+27
-1
@@ -4,8 +4,14 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
// maxMessageBodyLen is the maximum length of a message body in bytes. See PG_LARGE_MESSAGE_LIMIT in the PostgreSQL
|
||||
// source. It is defined as (MaxAllocSize - 1). MaxAllocSize is defined as 0x3fffffff.
|
||||
const maxMessageBodyLen = (0x3fffffff - 1)
|
||||
|
||||
// Message is the interface implemented by an object that can decode and encode
|
||||
// a particular PostgreSQL message.
|
||||
type Message interface {
|
||||
@@ -14,7 +20,7 @@ type Message interface {
|
||||
Decode(data []byte) error
|
||||
|
||||
// Encode appends itself to dst and returns the new buffer.
|
||||
Encode(dst []byte) []byte
|
||||
Encode(dst []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// FrontendMessage is a message sent by the frontend (i.e. the client).
|
||||
@@ -92,3 +98,23 @@ func getValueFromJSON(v map[string]string) ([]byte, error) {
|
||||
}
|
||||
return nil, errors.New("unknown protocol representation")
|
||||
}
|
||||
|
||||
// beginMessage begines a new message of type t. It appends the message type and a placeholder for the message length to
|
||||
// dst. It returns the new buffer and the position of the message length placeholder.
|
||||
func beginMessage(dst []byte, t byte) ([]byte, int) {
|
||||
dst = append(dst, t)
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
return dst, sp
|
||||
}
|
||||
|
||||
// finishMessage finishes a message that was started with beginMessage. It computes the message length and writes it to
|
||||
// dst[sp]. If the message length is too large it returns an error. Otherwise it returns the final message buffer.
|
||||
func finishMessage(dst []byte, sp int) ([]byte, error) {
|
||||
messageBodyLen := len(dst[sp:])
|
||||
if messageBodyLen > maxMessageBodyLen {
|
||||
return nil, errors.New("message body too large")
|
||||
}
|
||||
pgio.SetInt32(dst[sp:], int32(messageBodyLen))
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
package pgproto3
|
||||
|
||||
const MaxMessageBodyLen = maxMessageBodyLen
|
||||
@@ -20,8 +20,8 @@ func (dst *PortalSuspended) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *PortalSuspended) Encode(dst []byte) []byte {
|
||||
return append(dst, 's', 0, 0, 0, 4)
|
||||
func (src *PortalSuspended) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, 's', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
+3
-8
@@ -3,8 +3,6 @@ package pgproto3
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type Query struct {
|
||||
@@ -28,14 +26,11 @@ func (dst *Query) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Query) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'Q')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1))
|
||||
|
||||
func (src *Query) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'Q')
|
||||
dst = append(dst, src.String...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
package pgproto3_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/pgproto3"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestQueryBiggerThanMaxMessageBodyLen(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Maximum allowed size. 4 bytes for size and 1 byte for 0 terminated string.
|
||||
_, err := (&pgproto3.Query{String: string(make([]byte, pgproto3.MaxMessageBodyLen-5))}).Encode(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 1 byte too big
|
||||
_, err = (&pgproto3.Query{String: string(make([]byte, pgproto3.MaxMessageBodyLen-4))}).Encode(nil)
|
||||
require.Error(t, err)
|
||||
}
|
||||
@@ -25,8 +25,8 @@ func (dst *ReadyForQuery) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *ReadyForQuery) Encode(dst []byte) []byte {
|
||||
return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus)
|
||||
func (src *ReadyForQuery) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||
)
|
||||
@@ -99,11 +101,12 @@ func (dst *RowDescription) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *RowDescription) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'T')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *RowDescription) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'T')
|
||||
|
||||
if len(src.Fields) > math.MaxUint16 {
|
||||
return nil, errors.New("too many fields")
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.Fields)))
|
||||
for _, fd := range src.Fields {
|
||||
dst = append(dst, fd.Name...)
|
||||
@@ -117,9 +120,7 @@ func (src *RowDescription) Encode(dst []byte) []byte {
|
||||
dst = pgio.AppendInt16(dst, fd.Format)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -39,10 +39,8 @@ func (dst *SASLInitialResponse) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *SASLInitialResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'p')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
func (src *SASLInitialResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'p')
|
||||
|
||||
dst = append(dst, []byte(src.AuthMechanism)...)
|
||||
dst = append(dst, 0)
|
||||
@@ -50,9 +48,7 @@ func (src *SASLInitialResponse) Encode(dst []byte) []byte {
|
||||
dst = pgio.AppendInt32(dst, int32(len(src.Data)))
|
||||
dst = append(dst, src.Data...)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -3,8 +3,6 @@ package pgproto3
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type SASLResponse struct {
|
||||
@@ -22,13 +20,10 @@ func (dst *SASLResponse) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *SASLResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'p')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
|
||||
|
||||
func (src *SASLResponse) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'p')
|
||||
dst = append(dst, src.Data...)
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -31,10 +31,10 @@ func (dst *SSLRequest) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||
func (src *SSLRequest) Encode(dst []byte) []byte {
|
||||
func (src *SSLRequest) Encode(dst []byte) ([]byte, error) {
|
||||
dst = pgio.AppendInt32(dst, 8)
|
||||
dst = pgio.AppendInt32(dst, sslRequestNumber)
|
||||
return dst
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -64,7 +64,7 @@ func (dst *StartupMessage) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *StartupMessage) Encode(dst []byte) []byte {
|
||||
func (src *StartupMessage) Encode(dst []byte) ([]byte, error) {
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
@@ -77,9 +77,7 @@ func (src *StartupMessage) Encode(dst []byte) []byte {
|
||||
}
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
+2
-2
@@ -20,8 +20,8 @@ func (dst *Sync) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Sync) Encode(dst []byte) []byte {
|
||||
return append(dst, 'S', 0, 0, 0, 4)
|
||||
func (src *Sync) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, 'S', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
@@ -20,8 +20,8 @@ func (dst *Terminate) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Terminate) Encode(dst []byte) []byte {
|
||||
return append(dst, 'X', 0, 0, 0, 4)
|
||||
func (src *Terminate) Encode(dst []byte) ([]byte, error) {
|
||||
return append(dst, 'X', 0, 0, 0, 4), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
|
||||
+3
-1
@@ -176,8 +176,10 @@ func (scanPlanBinaryBitsToBitsScanner) Scan(src []byte, dst any) error {
|
||||
|
||||
bitLen := int32(binary.BigEndian.Uint32(src))
|
||||
rp := 4
|
||||
buf := make([]byte, len(src[rp:]))
|
||||
copy(buf, src[rp:])
|
||||
|
||||
return scanner.ScanBits(Bits{Bytes: src[rp:], Len: bitLen, Valid: true})
|
||||
return scanner.ScanBits(Bits{Bytes: buf, Len: bitLen, Valid: true})
|
||||
}
|
||||
|
||||
type scanPlanTextAnyToBitsScanner struct{}
|
||||
|
||||
+2
-2
@@ -297,12 +297,12 @@ func (c Float4Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, sr
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var n float64
|
||||
var n float32
|
||||
err := codecScan(c, m, oid, format, src, &n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return n, nil
|
||||
return float64(n), nil
|
||||
}
|
||||
|
||||
func (c Float4Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
||||
|
||||
@@ -25,6 +25,11 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
|
||||
case []byte:
|
||||
return encodePlanJSONCodecEitherFormatByteSlice{}
|
||||
|
||||
// Handle json.RawMessage specifically because if it is run through json.Marshal it may be mutated.
|
||||
// e.g. `{"foo": "bar"}` -> `{"foo":"bar"}`.
|
||||
case json.RawMessage:
|
||||
return encodePlanJSONCodecEitherFormatJSONRawMessage{}
|
||||
|
||||
// Cannot rely on driver.Valuer being handled later because anything can be marshalled.
|
||||
//
|
||||
// https://github.com/jackc/pgx/issues/1430
|
||||
@@ -79,6 +84,18 @@ func (encodePlanJSONCodecEitherFormatByteSlice) Encode(value any, buf []byte) (n
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
type encodePlanJSONCodecEitherFormatJSONRawMessage struct{}
|
||||
|
||||
func (encodePlanJSONCodecEitherFormatJSONRawMessage) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||
jsonBytes := value.(json.RawMessage)
|
||||
if jsonBytes == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
buf = append(buf, jsonBytes...)
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
type encodePlanJSONCodecEitherFormatMarshal struct{}
|
||||
|
||||
func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||
|
||||
+122
@@ -0,0 +1,122 @@
|
||||
package pgtype
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type LtreeCodec struct{}
|
||||
|
||||
func (l LtreeCodec) FormatSupported(format int16) bool {
|
||||
return format == TextFormatCode || format == BinaryFormatCode
|
||||
}
|
||||
|
||||
// PreferredFormat returns the preferred format.
|
||||
func (l LtreeCodec) PreferredFormat() int16 {
|
||||
return TextFormatCode
|
||||
}
|
||||
|
||||
// PlanEncode returns an EncodePlan for encoding value into PostgreSQL format for oid and format. If no plan can be
|
||||
// found then nil is returned.
|
||||
func (l LtreeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
|
||||
switch format {
|
||||
case TextFormatCode:
|
||||
return (TextCodec)(l).PlanEncode(m, oid, format, value)
|
||||
case BinaryFormatCode:
|
||||
switch value.(type) {
|
||||
case string:
|
||||
return encodeLtreeCodecBinaryString{}
|
||||
case []byte:
|
||||
return encodeLtreeCodecBinaryByteSlice{}
|
||||
case TextValuer:
|
||||
return encodeLtreeCodecBinaryTextValuer{}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type encodeLtreeCodecBinaryString struct{}
|
||||
|
||||
func (encodeLtreeCodecBinaryString) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||
ltree := value.(string)
|
||||
buf = append(buf, 1)
|
||||
return append(buf, ltree...), nil
|
||||
}
|
||||
|
||||
type encodeLtreeCodecBinaryByteSlice struct{}
|
||||
|
||||
func (encodeLtreeCodecBinaryByteSlice) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||
ltree := value.([]byte)
|
||||
buf = append(buf, 1)
|
||||
return append(buf, ltree...), nil
|
||||
}
|
||||
|
||||
type encodeLtreeCodecBinaryTextValuer struct{}
|
||||
|
||||
func (encodeLtreeCodecBinaryTextValuer) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||
t, err := value.(TextValuer).TextValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !t.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
buf = append(buf, 1)
|
||||
return append(buf, t.String...), nil
|
||||
}
|
||||
|
||||
// PlanScan returns a ScanPlan for scanning a PostgreSQL value into a destination with the same type as target. If
|
||||
// no plan can be found then nil is returned.
|
||||
func (l LtreeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||
switch format {
|
||||
case TextFormatCode:
|
||||
return (TextCodec)(l).PlanScan(m, oid, format, target)
|
||||
case BinaryFormatCode:
|
||||
switch target.(type) {
|
||||
case *string:
|
||||
return scanPlanBinaryLtreeToString{}
|
||||
case TextScanner:
|
||||
return scanPlanBinaryLtreeToTextScanner{}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type scanPlanBinaryLtreeToString struct{}
|
||||
|
||||
func (scanPlanBinaryLtreeToString) Scan(src []byte, target any) error {
|
||||
version := src[0]
|
||||
if version != 1 {
|
||||
return fmt.Errorf("unsupported ltree version %d", version)
|
||||
}
|
||||
|
||||
p := (target).(*string)
|
||||
*p = string(src[1:])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type scanPlanBinaryLtreeToTextScanner struct{}
|
||||
|
||||
func (scanPlanBinaryLtreeToTextScanner) Scan(src []byte, target any) error {
|
||||
version := src[0]
|
||||
if version != 1 {
|
||||
return fmt.Errorf("unsupported ltree version %d", version)
|
||||
}
|
||||
|
||||
scanner := (target).(TextScanner)
|
||||
return scanner.ScanText(Text{String: string(src[1:]), Valid: true})
|
||||
}
|
||||
|
||||
// DecodeDatabaseSQLValue returns src decoded into a value compatible with the sql.Scanner interface.
|
||||
func (l LtreeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
|
||||
return (TextCodec)(l).DecodeDatabaseSQLValue(m, oid, format, src)
|
||||
}
|
||||
|
||||
// DecodeValue returns src decoded into its default format.
|
||||
func (l LtreeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
||||
return (TextCodec)(l).DecodeValue(m, oid, format, src)
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package pgtype_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/andoma-go/pgx/v5/pgtype"
|
||||
"github.com/andoma-go/pgx/v5/pgxtest"
|
||||
)
|
||||
|
||||
func TestLtreeCodec(t *testing.T) {
|
||||
skipCockroachDB(t, "Server does not support type ltree")
|
||||
|
||||
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "ltree", []pgxtest.ValueRoundTripTest{
|
||||
{
|
||||
Param: "A.B.C",
|
||||
Result: new(string),
|
||||
Test: isExpectedEq("A.B.C"),
|
||||
},
|
||||
{
|
||||
Param: pgtype.Text{String: "", Valid: true},
|
||||
Result: new(pgtype.Text),
|
||||
Test: isExpectedEq(pgtype.Text{String: "", Valid: true}),
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -48,4 +48,23 @@ func TestMacaddrCodec(t *testing.T) {
|
||||
},
|
||||
{nil, new(*net.HardwareAddr), isExpectedEq((*net.HardwareAddr)(nil))},
|
||||
})
|
||||
|
||||
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "macaddr8", []pgxtest.ValueRoundTripTest{
|
||||
{
|
||||
mustParseMacaddr(t, "01:23:45:67:89:ab:01:08"),
|
||||
new(net.HardwareAddr),
|
||||
isExpectedEqHardwareAddr(mustParseMacaddr(t, "01:23:45:67:89:ab:01:08")),
|
||||
},
|
||||
{
|
||||
"01:23:45:67:89:ab:01:08",
|
||||
new(net.HardwareAddr),
|
||||
isExpectedEqHardwareAddr(mustParseMacaddr(t, "01:23:45:67:89:ab:01:08")),
|
||||
},
|
||||
{
|
||||
mustParseMacaddr(t, "01:23:45:67:89:ab:01:08"),
|
||||
new(string),
|
||||
isExpectedEq("01:23:45:67:89:ab:01:08"),
|
||||
},
|
||||
{nil, new(*net.HardwareAddr), isExpectedEq((*net.HardwareAddr)(nil))},
|
||||
})
|
||||
}
|
||||
|
||||
+4
-1
@@ -41,6 +41,7 @@ const (
|
||||
CircleOID = 718
|
||||
CircleArrayOID = 719
|
||||
UnknownOID = 705
|
||||
Macaddr8OID = 774
|
||||
MacaddrOID = 829
|
||||
InetOID = 869
|
||||
BoolArrayOID = 1000
|
||||
@@ -81,6 +82,8 @@ const (
|
||||
IntervalOID = 1186
|
||||
IntervalArrayOID = 1187
|
||||
NumericArrayOID = 1231
|
||||
TimetzOID = 1266
|
||||
TimetzArrayOID = 1270
|
||||
BitOID = 1560
|
||||
BitArrayOID = 1561
|
||||
VarbitOID = 1562
|
||||
@@ -559,7 +562,7 @@ func TryFindUnderlyingTypeScanPlan(dst any) (plan WrappedScanPlanNextSetter, nex
|
||||
}
|
||||
}
|
||||
|
||||
if nextDstType != nil && dstValue.Type() != nextDstType {
|
||||
if nextDstType != nil && dstValue.Type() != nextDstType && dstValue.CanConvert(nextDstType) {
|
||||
return &underlyingTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package pgtype
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
@@ -69,6 +70,7 @@ func initDefaultMap() {
|
||||
defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}})
|
||||
defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}})
|
||||
defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}})
|
||||
defaultMap.RegisterType(&Type{Name: "macaddr8", OID: Macaddr8OID, Codec: MacaddrCodec{}})
|
||||
defaultMap.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}})
|
||||
defaultMap.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}})
|
||||
defaultMap.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}})
|
||||
@@ -173,6 +175,7 @@ func initDefaultMap() {
|
||||
registerDefaultPgTypeVariants[time.Time](defaultMap, "timestamptz")
|
||||
registerDefaultPgTypeVariants[time.Duration](defaultMap, "interval")
|
||||
registerDefaultPgTypeVariants[string](defaultMap, "text")
|
||||
registerDefaultPgTypeVariants[json.RawMessage](defaultMap, "json")
|
||||
registerDefaultPgTypeVariants[[]byte](defaultMap, "bytea")
|
||||
|
||||
registerDefaultPgTypeVariants[net.IP](defaultMap, "inet")
|
||||
|
||||
@@ -35,6 +35,7 @@ func init() {
|
||||
// Test for renamed types
|
||||
type _string string
|
||||
type _bool bool
|
||||
type _uint8 uint8
|
||||
type _int8 int8
|
||||
type _int16 int16
|
||||
type _int16Slice []int16
|
||||
@@ -453,6 +454,14 @@ func TestMapScanNullToWrongType(t *testing.T) {
|
||||
assert.False(t, pn.Valid)
|
||||
}
|
||||
|
||||
func TestScanToSliceOfRenamedUint8(t *testing.T) {
|
||||
m := pgtype.NewMap()
|
||||
var ruint8 []_uint8
|
||||
err := m.Scan(pgtype.Int2ArrayOID, pgx.TextFormatCode, []byte("{2,4}"), &ruint8)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []_uint8{2, 4}, ruint8)
|
||||
}
|
||||
|
||||
func TestMapScanTextToBool(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -537,6 +546,14 @@ func TestMapEncodePlanCacheUUIDTypeConfusion(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/1763
|
||||
func TestMapEncodeRawJSONIntoUnknownOID(t *testing.T) {
|
||||
m := pgtype.NewMap()
|
||||
buf, err := m.Encode(0, pgtype.TextFormatCode, json.RawMessage(`{"foo": "bar"}`), nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte(`{"foo": "bar"}`), buf)
|
||||
}
|
||||
|
||||
func BenchmarkMapScanInt4IntoBinaryDecoder(b *testing.B) {
|
||||
m := pgtype.NewMap()
|
||||
src := []byte{0, 0, 0, 42}
|
||||
|
||||
+13
-1
@@ -52,7 +52,19 @@ func parseUUID(src string) (dst [16]byte, err error) {
|
||||
|
||||
// encodeUUID converts a uuid byte array to UUID standard string form.
|
||||
func encodeUUID(src [16]byte) string {
|
||||
return fmt.Sprintf("%x-%x-%x-%x-%x", src[0:4], src[4:6], src[6:8], src[8:10], src[10:16])
|
||||
var buf [36]byte
|
||||
|
||||
hex.Encode(buf[0:8], src[:4])
|
||||
buf[8] = '-'
|
||||
hex.Encode(buf[9:13], src[4:6])
|
||||
buf[13] = '-'
|
||||
hex.Encode(buf[14:18], src[6:8])
|
||||
buf[18] = '-'
|
||||
hex.Encode(buf[19:23], src[8:10])
|
||||
buf[23] = '-'
|
||||
hex.Encode(buf[24:], src[10:])
|
||||
|
||||
return string(buf[:])
|
||||
}
|
||||
|
||||
// Scan implements the database/sql Scanner interface.
|
||||
|
||||
@@ -417,12 +417,10 @@ type CollectableRow interface {
|
||||
// RowToFunc is a function that scans or otherwise converts row to a T.
|
||||
type RowToFunc[T any] func(row CollectableRow) (T, error)
|
||||
|
||||
// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T.
|
||||
func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
|
||||
// AppendRows iterates through rows, calling fn for each row, and appending the results into a slice of T.
|
||||
func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) {
|
||||
defer rows.Close()
|
||||
|
||||
slice := []T{}
|
||||
|
||||
for rows.Next() {
|
||||
value, err := fn(rows)
|
||||
if err != nil {
|
||||
@@ -438,6 +436,11 @@ func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T.
|
||||
func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
|
||||
return AppendRows([]T{}, rows, fn)
|
||||
}
|
||||
|
||||
// CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true.
|
||||
// CollectOneRow is to CollectRows as QueryRow is to Query.
|
||||
func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
|
||||
|
||||
@@ -175,6 +175,21 @@ func TestCollectRows(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestCollectRowsEmpty(t *testing.T) {
|
||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
rows, _ := conn.Query(ctx, `select n from generate_series(1, 0) n`)
|
||||
numbers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (int32, error) {
|
||||
var n int32
|
||||
err := row.Scan(&n)
|
||||
return n, err
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, numbers)
|
||||
|
||||
assert.Empty(t, numbers)
|
||||
})
|
||||
}
|
||||
|
||||
// This example uses CollectRows with a manually written collector function. In most cases RowTo, RowToAddrOf,
|
||||
// RowToStructByPos, RowToAddrOfStructByPos, or another generic function would be used.
|
||||
func ExampleCollectRows() {
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
[ req ]
|
||||
distinguished_name = dn
|
||||
[ dn ]
|
||||
commonName = ca
|
||||
[ ext ]
|
||||
basicConstraints =CA:TRUE,pathlen:0
|
||||
@@ -0,0 +1,187 @@
|
||||
// Generates a CA, server certificate, and encrypted client certificate for testing pgx.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Create the CA
|
||||
ca := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
CommonName: "pgx-root-ca",
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(20, 0, 0),
|
||||
IsCA: true,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
caKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caKey.PublicKey, caKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = writePrivateKey("ca.key", caKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = writeCertificate("ca.pem", caBytes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Create a server certificate signed by the CA for localhost.
|
||||
serverCert := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(2),
|
||||
Subject: pkix.Name{
|
||||
CommonName: "localhost",
|
||||
},
|
||||
DNSNames: []string{"localhost"},
|
||||
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(20, 0, 0),
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
}
|
||||
|
||||
serverCertPrivKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
serverBytes, err := x509.CreateCertificate(rand.Reader, serverCert, ca, &serverCertPrivKey.PublicKey, caKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = writePrivateKey("localhost.key", serverCertPrivKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = writeCertificate("localhost.crt", serverBytes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Create a client certificate signed by the CA and encrypted.
|
||||
clientCert := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(3),
|
||||
Subject: pkix.Name{
|
||||
CommonName: "pgx_sslcert",
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(20, 0, 0),
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
}
|
||||
|
||||
clientCertPrivKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
clientBytes, err := x509.CreateCertificate(rand.Reader, clientCert, ca, &clientCertPrivKey.PublicKey, caKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
writeEncryptedPrivateKey("pgx_sslcert.key", clientCertPrivKey, "certpw")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
writeCertificate("pgx_sslcert.crt", clientBytes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func writePrivateKey(path string, privateKey *rsa.PrivateKey) error {
|
||||
file, err := os.Create(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writePrivateKey: %w", err)
|
||||
}
|
||||
|
||||
err = pem.Encode(file, &pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("writePrivateKey: %w", err)
|
||||
}
|
||||
|
||||
err = file.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("writePrivateKey: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeEncryptedPrivateKey(path string, privateKey *rsa.PrivateKey, password string) error {
|
||||
file, err := os.Create(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writeEncryptedPrivateKey: %w", err)
|
||||
}
|
||||
|
||||
block, err := x509.EncryptPEMBlock(rand.Reader, "CERTIFICATE", x509.MarshalPKCS1PrivateKey(privateKey), []byte(password), x509.PEMCipher3DES)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writeEncryptedPrivateKey: %w", err)
|
||||
}
|
||||
|
||||
err = pem.Encode(file, block)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writeEncryptedPrivateKey: %w", err)
|
||||
}
|
||||
|
||||
err = file.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("writeEncryptedPrivateKey: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func writeCertificate(path string, certBytes []byte) error {
|
||||
file, err := os.Create(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writeCertificate: %w", err)
|
||||
}
|
||||
|
||||
err = pem.Encode(file, &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certBytes,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("writeCertificate: %w", err)
|
||||
}
|
||||
|
||||
err = file.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("writeCertificate: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
[ req ]
|
||||
default_bits = 2048
|
||||
distinguished_name = dn
|
||||
req_extensions = v3_req
|
||||
prompt = no
|
||||
[ dn ]
|
||||
commonName = localhost
|
||||
[ v3_req ]
|
||||
subjectAltName = @alt_names
|
||||
keyUsage = digitalSignature
|
||||
extendedKeyUsage = serverAuth
|
||||
[alt_names]
|
||||
DNS.1 = localhost
|
||||
@@ -1,9 +0,0 @@
|
||||
[ req ]
|
||||
default_bits = 2048
|
||||
distinguished_name = dn
|
||||
req_extensions = v3_req
|
||||
prompt = no
|
||||
[ dn ]
|
||||
commonName = pgx_sslcert
|
||||
[ v3_req ]
|
||||
keyUsage = digitalSignature
|
||||
@@ -1,5 +1,6 @@
|
||||
-- Create extensions and types.
|
||||
create extension hstore;
|
||||
create extension ltree;
|
||||
create domain uint64 as numeric(20,0);
|
||||
|
||||
-- Create users for different types of connections and authentication.
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user