Use extracted packages with Go modules
This commit is contained in:
@@ -56,7 +56,7 @@ database/sql compatibility layer for pgx. pgx can be used as a normal database/s
|
|||||||
|
|
||||||
Approximately 60 PostgreSQL types are supported including uuid, hstore, json, bytea, numeric, interval, inet, and arrays. These types support database/sql interfaces and are usable even outside of pgx. They are fully tested in pgx and pq. They also support a higher performance interface when used with the pgx driver.
|
Approximately 60 PostgreSQL types are supported including uuid, hstore, json, bytea, numeric, interval, inet, and arrays. These types support database/sql interfaces and are usable even outside of pgx. They are fully tested in pgx and pq. They also support a higher performance interface when used with the pgx driver.
|
||||||
|
|
||||||
## github.com/jackc/pgx/pgproto3
|
## github.com/jackc/pgproto3
|
||||||
|
|
||||||
pgproto3 provides standalone encoding and decoding of the PostgreSQL v3 wire protocol. This is useful for implementing very low level PostgreSQL tooling.
|
pgproto3 provides standalone encoding and decoding of the PostgreSQL v3 wire protocol. This is useful for implementing very low level PostgreSQL tooling.
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package pgx
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgconn"
|
"github.com/jackc/pgconn"
|
||||||
"github.com/jackc/pgx/pgtype"
|
"github.com/jackc/pgx/pgtype"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|||||||
+1
-1
@@ -5,8 +5,8 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
"github.com/jackc/pgx"
|
"github.com/jackc/pgx"
|
||||||
"github.com/jackc/pgx/pgconn"
|
|
||||||
"github.com/jackc/pgx/pgtype"
|
"github.com/jackc/pgx/pgtype"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,89 +0,0 @@
|
|||||||
package chunkreader
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ChunkReader struct {
|
|
||||||
r io.Reader
|
|
||||||
|
|
||||||
buf []byte
|
|
||||||
rp, wp int // buf read position and write position
|
|
||||||
|
|
||||||
options Options
|
|
||||||
}
|
|
||||||
|
|
||||||
type Options struct {
|
|
||||||
MinBufLen int // Minimum buffer length
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewChunkReader(r io.Reader) *ChunkReader {
|
|
||||||
cr, err := NewChunkReaderEx(r, Options{})
|
|
||||||
if err != nil {
|
|
||||||
panic("default options can't be bad")
|
|
||||||
}
|
|
||||||
|
|
||||||
return cr
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) {
|
|
||||||
if options.MinBufLen == 0 {
|
|
||||||
options.MinBufLen = 4096
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ChunkReader{
|
|
||||||
r: r,
|
|
||||||
buf: make([]byte, options.MinBufLen),
|
|
||||||
options: options,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next returns buf filled with the next n bytes. If an error occurs, buf will
|
|
||||||
// be nil.
|
|
||||||
func (r *ChunkReader) Next(n int) (buf []byte, err error) {
|
|
||||||
// n bytes already in buf
|
|
||||||
if (r.wp - r.rp) >= n {
|
|
||||||
buf = r.buf[r.rp : r.rp+n]
|
|
||||||
r.rp += n
|
|
||||||
return buf, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// available space in buf is less than n
|
|
||||||
if len(r.buf) < n {
|
|
||||||
r.copyBufContents(r.newBuf(n))
|
|
||||||
}
|
|
||||||
|
|
||||||
// buf is large enough, but need to shift filled area to start to make enough contiguous space
|
|
||||||
minReadCount := n - (r.wp - r.rp)
|
|
||||||
if (len(r.buf) - r.wp) < minReadCount {
|
|
||||||
newBuf := r.newBuf(n)
|
|
||||||
r.copyBufContents(newBuf)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.appendAtLeast(minReadCount); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
buf = r.buf[r.rp : r.rp+n]
|
|
||||||
r.rp += n
|
|
||||||
return buf, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ChunkReader) appendAtLeast(fillLen int) error {
|
|
||||||
n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen)
|
|
||||||
r.wp += n
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ChunkReader) newBuf(size int) []byte {
|
|
||||||
if size < r.options.MinBufLen {
|
|
||||||
size = r.options.MinBufLen
|
|
||||||
}
|
|
||||||
return make([]byte, size)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ChunkReader) copyBufContents(dest []byte) {
|
|
||||||
r.wp = copy(dest, r.buf[r.rp:r.wp])
|
|
||||||
r.rp = 0
|
|
||||||
r.buf = dest
|
|
||||||
}
|
|
||||||
@@ -1,96 +0,0 @@
|
|||||||
package chunkreader
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) {
|
|
||||||
server := &bytes.Buffer{}
|
|
||||||
r, err := NewChunkReaderEx(server, Options{MinBufLen: 4})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
src := []byte{1, 2, 3, 4}
|
|
||||||
server.Write(src)
|
|
||||||
|
|
||||||
n1, err := r.Next(2)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if bytes.Compare(n1, src[0:2]) != 0 {
|
|
||||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:2], n1)
|
|
||||||
}
|
|
||||||
|
|
||||||
n2, err := r.Next(2)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if bytes.Compare(n2, src[2:4]) != 0 {
|
|
||||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[2:4], n2)
|
|
||||||
}
|
|
||||||
|
|
||||||
if bytes.Compare(r.buf, src) != 0 {
|
|
||||||
t.Fatalf("Expected r.buf to be %v, but it was %v", src, r.buf)
|
|
||||||
}
|
|
||||||
if r.rp != 4 {
|
|
||||||
t.Fatalf("Expected r.rp to be %v, but it was %v", 4, r.rp)
|
|
||||||
}
|
|
||||||
if r.wp != 4 {
|
|
||||||
t.Fatalf("Expected r.wp to be %v, but it was %v", 4, r.wp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) {
|
|
||||||
server := &bytes.Buffer{}
|
|
||||||
r, err := NewChunkReaderEx(server, Options{MinBufLen: 4})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
src := []byte{1, 2, 3, 4, 5, 6, 7, 8}
|
|
||||||
server.Write(src)
|
|
||||||
|
|
||||||
n1, err := r.Next(5)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if bytes.Compare(n1, src[0:5]) != 0 {
|
|
||||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:5], n1)
|
|
||||||
}
|
|
||||||
if len(r.buf) != 5 {
|
|
||||||
t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 5, len(r.buf))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestChunkReaderDoesNotReuseBuf(t *testing.T) {
|
|
||||||
server := &bytes.Buffer{}
|
|
||||||
r, err := NewChunkReaderEx(server, Options{MinBufLen: 4})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
src := []byte{1, 2, 3, 4, 5, 6, 7, 8}
|
|
||||||
server.Write(src)
|
|
||||||
|
|
||||||
n1, err := r.Next(4)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if bytes.Compare(n1, src[0:4]) != 0 {
|
|
||||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:4], n1)
|
|
||||||
}
|
|
||||||
|
|
||||||
n2, err := r.Next(4)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if bytes.Compare(n2, src[4:8]) != 0 {
|
|
||||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2)
|
|
||||||
}
|
|
||||||
|
|
||||||
if bytes.Compare(n1, src[0:4]) != 0 {
|
|
||||||
t.Fatalf("Expected KeepLast to prevent Next from overwriting buf, expected %v but it was %v", src[0:4], n1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -12,8 +12,8 @@ import (
|
|||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgconn"
|
"github.com/jackc/pgconn"
|
||||||
"github.com/jackc/pgx/pgproto3"
|
"github.com/jackc/pgproto3"
|
||||||
"github.com/jackc/pgx/pgtype"
|
"github.com/jackc/pgx/pgtype"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgconn"
|
"github.com/jackc/pgconn"
|
||||||
"github.com/jackc/pgx/pgtype"
|
"github.com/jackc/pgx/pgtype"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -11,8 +11,8 @@ import (
|
|||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
"github.com/jackc/pgx"
|
"github.com/jackc/pgx"
|
||||||
"github.com/jackc/pgx/pgconn"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func createConnPool(t *testing.T, maxConnections int) *pgx.ConnPool {
|
func createConnPool(t *testing.T, maxConnections int) *pgx.ConnPool {
|
||||||
|
|||||||
+1
-1
@@ -10,8 +10,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
"github.com/jackc/pgx"
|
"github.com/jackc/pgx"
|
||||||
"github.com/jackc/pgx/pgconn"
|
|
||||||
"github.com/jackc/pgx/pgtype"
|
"github.com/jackc/pgx/pgtype"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|||||||
+1
-1
@@ -6,7 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -6,8 +6,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
"github.com/jackc/pgx"
|
"github.com/jackc/pgx"
|
||||||
"github.com/jackc/pgx/pgconn"
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+2
-2
@@ -3,8 +3,8 @@ package pgx
|
|||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/jackc/pgx/pgproto3"
|
"github.com/jackc/pgproto3"
|
||||||
"github.com/jackc/pgx/pgtype"
|
"github.com/jackc/pgx/pgtype"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,14 @@
|
|||||||
|
module github.com/jackc/pgx
|
||||||
|
|
||||||
|
go 1.12
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/cockroachdb/apd v1.1.0
|
||||||
|
github.com/jackc/pgconn v0.0.0-20190330221323-ed7d91dc9873
|
||||||
|
github.com/jackc/pgio v1.0.0
|
||||||
|
github.com/jackc/pgproto3 v1.0.0
|
||||||
|
github.com/pkg/errors v0.8.1
|
||||||
|
github.com/satori/go.uuid v1.2.0
|
||||||
|
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24
|
||||||
|
github.com/stretchr/testify v1.3.0
|
||||||
|
)
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
|
||||||
|
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
|
||||||
|
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0=
|
||||||
|
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
|
||||||
|
github.com/jackc/pgconn v0.0.0-20190330221323-ed7d91dc9873 h1:M68R77AKFS7dub7R7WgJ9D6yiNuExOYhBuGtazlbr10=
|
||||||
|
github.com/jackc/pgconn v0.0.0-20190330221323-ed7d91dc9873/go.mod h1:8Bzf8vzi/ZpcgLgrq8IUHjZX4ZU+Hf6N6/AJ85+fDeE=
|
||||||
|
github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
|
||||||
|
github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8=
|
||||||
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
|
github.com/jackc/pgproto3 v1.0.0 h1:25tUmlES7eyD96oYaUHc1dLOFbgcJtFzCdnOOoqmA1I=
|
||||||
|
github.com/jackc/pgproto3 v1.0.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
|
||||||
|
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||||
|
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww=
|
||||||
|
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
|
||||||
|
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 h1:pntxY8Ary0t43dCZ5dqY4YTJCObLY1kIXl0uzMv+7DE=
|
||||||
|
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
|
||||||
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
+1
-1
@@ -4,8 +4,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
"github.com/jackc/pgx"
|
"github.com/jackc/pgx"
|
||||||
"github.com/jackc/pgx/pgconn"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
"github.com/jackc/pgx"
|
"github.com/jackc/pgx"
|
||||||
"github.com/jackc/pgx/pgconn"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestLargeObjects(t *testing.T) {
|
func TestLargeObjects(t *testing.T) {
|
||||||
|
|||||||
+1
-1
@@ -6,7 +6,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/jackc/pgx/pgtype"
|
"github.com/jackc/pgx/pgtype"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,101 +0,0 @@
|
|||||||
package pgconn_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgconn"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func BenchmarkConnect(b *testing.B) {
|
|
||||||
benchmarks := []struct {
|
|
||||||
name string
|
|
||||||
env string
|
|
||||||
}{
|
|
||||||
{"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
|
|
||||||
{"TCP", "PGX_TEST_TCP_CONN_STRING"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, bm := range benchmarks {
|
|
||||||
b.Run(bm.name, func(b *testing.B) {
|
|
||||||
connString := os.Getenv(bm.env)
|
|
||||||
if connString == "" {
|
|
||||||
b.Skipf("Skipping due to missing environment variable %v", bm.env)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
conn, err := pgconn.Connect(context.Background(), connString)
|
|
||||||
require.Nil(b, err)
|
|
||||||
|
|
||||||
err = conn.Close(context.Background())
|
|
||||||
require.Nil(b, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkExec(b *testing.B) {
|
|
||||||
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
|
||||||
require.Nil(b, err)
|
|
||||||
defer closeConn(b, conn)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_, err := conn.Exec(context.Background(), "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date").ReadAll()
|
|
||||||
require.Nil(b, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkExecPossibleToCancel(b *testing.B) {
|
|
||||||
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
|
||||||
require.Nil(b, err)
|
|
||||||
defer closeConn(b, conn)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_, err := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date").ReadAll()
|
|
||||||
require.Nil(b, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkExecPrepared(b *testing.B) {
|
|
||||||
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
|
||||||
require.Nil(b, err)
|
|
||||||
defer closeConn(b, conn)
|
|
||||||
|
|
||||||
_, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
|
|
||||||
require.Nil(b, err)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
result := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil).Read()
|
|
||||||
require.Nil(b, result.Err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkExecPreparedPossibleToCancel(b *testing.B) {
|
|
||||||
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
|
||||||
require.Nil(b, err)
|
|
||||||
defer closeConn(b, conn)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
_, err = conn.Prepare(ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
|
|
||||||
require.Nil(b, err)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
result := conn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
|
|
||||||
require.Nil(b, result.Err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,501 +0,0 @@
|
|||||||
package pgconn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"math"
|
|
||||||
"net"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"os/user"
|
|
||||||
"path/filepath"
|
|
||||||
"regexp"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgpassfile"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
|
||||||
|
|
||||||
// Config is the settings used to establish a connection to a PostgreSQL server.
|
|
||||||
type Config struct {
|
|
||||||
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
|
||||||
Port uint16
|
|
||||||
Database string
|
|
||||||
User string
|
|
||||||
Password string
|
|
||||||
TLSConfig *tls.Config // nil disables TLS
|
|
||||||
DialFunc DialFunc // e.g. net.Dialer.DialContext
|
|
||||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
|
||||||
|
|
||||||
Fallbacks []*FallbackConfig
|
|
||||||
|
|
||||||
// AfterConnectFunc is called after successful connection. It can be used to set up the connection or to validate that
|
|
||||||
// server is acceptable. If this returns an error the connection is closed and the next fallback config is tried. This
|
|
||||||
// allows implementing high availability behavior such as libpq does with target_session_attrs.
|
|
||||||
AfterConnectFunc AfterConnectFunc
|
|
||||||
|
|
||||||
// OnContextCancel is a callback function used to override cancellation behavior. It is called when a context.Context
|
|
||||||
// is canceled. Default cancellation behavior is to establish another connection to the PostgreSQL server and send a
|
|
||||||
// query cancel request. Some non-PostgreSQL servers (e.g. CockroachDB) that speak a subset of the PostgreSQL wire
|
|
||||||
// protocol do not support this cancellation method.
|
|
||||||
//
|
|
||||||
// It is called from a background goroutine. When the cancellation process has finished ContextCancel.Finish must be
|
|
||||||
// called whether it was successful or not. If an error occurs the connection should be closed. The connection must be
|
|
||||||
// in a ready for query state or be closed when ContextCancel.Finish is called. Use PgConn.ReceiveMessage() to read
|
|
||||||
// the connection until a ready for query message is received.
|
|
||||||
OnContextCancel func(*ContextCancel)
|
|
||||||
|
|
||||||
// OnNotice is a callback function called when a notice response is received.
|
|
||||||
OnNotice NoticeHandler
|
|
||||||
|
|
||||||
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
|
|
||||||
OnNotification NotificationHandler
|
|
||||||
}
|
|
||||||
|
|
||||||
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
|
|
||||||
// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections.
|
|
||||||
type FallbackConfig struct {
|
|
||||||
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
|
||||||
Port uint16
|
|
||||||
TLSConfig *tls.Config // nil disables TLS
|
|
||||||
}
|
|
||||||
|
|
||||||
// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
|
|
||||||
// net.Dial.
|
|
||||||
func NetworkAddress(host string, port uint16) (network, address string) {
|
|
||||||
if strings.HasPrefix(host, "/") {
|
|
||||||
network = "unix"
|
|
||||||
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
|
|
||||||
} else {
|
|
||||||
network = "tcp"
|
|
||||||
address = fmt.Sprintf("%s:%d", host, port)
|
|
||||||
}
|
|
||||||
return network, address
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseConfig builds a []*Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same
|
|
||||||
// defaults as libpq (e.g. port=5432) and understands most PG* environment variables. connString may be a URL or a DSN.
|
|
||||||
// It also may be empty to only read from the environment. If a password is not supplied it will attempt to read the
|
|
||||||
// .pgpass file.
|
|
||||||
//
|
|
||||||
// # Example DSN
|
|
||||||
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
|
|
||||||
//
|
|
||||||
// # Example URL
|
|
||||||
// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca
|
|
||||||
//
|
|
||||||
// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated
|
|
||||||
// values that will be tried in order. This can be used as part of a high availability system. See
|
|
||||||
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
|
|
||||||
//
|
|
||||||
// # Example URL
|
|
||||||
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
|
|
||||||
//
|
|
||||||
// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed
|
|
||||||
// via database URL or DSN:
|
|
||||||
//
|
|
||||||
// PGHOST
|
|
||||||
// PGPORT
|
|
||||||
// PGDATABASE
|
|
||||||
// PGUSER
|
|
||||||
// PGPASSWORD
|
|
||||||
// PGPASSFILE
|
|
||||||
// PGSSLMODE
|
|
||||||
// PGSSLCERT
|
|
||||||
// PGSSLKEY
|
|
||||||
// PGSSLROOTCERT
|
|
||||||
// PGAPPNAME
|
|
||||||
// PGCONNECT_TIMEOUT
|
|
||||||
// PGTARGETSESSIONATTRS
|
|
||||||
//
|
|
||||||
// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables.
|
|
||||||
//
|
|
||||||
// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are
|
|
||||||
// usually but not always the environment variable name downcased and without the "PG" prefix.
|
|
||||||
//
|
|
||||||
// Important TLS Security Notes:
|
|
||||||
//
|
|
||||||
// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if
|
|
||||||
// not set.
|
|
||||||
//
|
|
||||||
// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of
|
|
||||||
// security each sslmode provides.
|
|
||||||
//
|
|
||||||
// "verify-ca" mode currently is treated as "verify-full". e.g. It has stronger
|
|
||||||
// security guarantees than it would with libpq. Do not rely on this behavior as it
|
|
||||||
// may be possible to match libpq in the future. If you need full security use
|
|
||||||
// "verify-full".
|
|
||||||
func ParseConfig(connString string) (*Config, error) {
|
|
||||||
settings := defaultSettings()
|
|
||||||
addEnvSettings(settings)
|
|
||||||
|
|
||||||
if connString != "" {
|
|
||||||
// connString may be a database URL or a DSN
|
|
||||||
if strings.HasPrefix(connString, "postgres://") {
|
|
||||||
err := addURLSettings(settings, connString)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
err := addDSNSettings(settings, connString)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
config := &Config{
|
|
||||||
Database: settings["database"],
|
|
||||||
User: settings["user"],
|
|
||||||
Password: settings["password"],
|
|
||||||
RuntimeParams: make(map[string]string),
|
|
||||||
}
|
|
||||||
|
|
||||||
if connectTimeout, present := settings["connect_timeout"]; present {
|
|
||||||
dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
config.DialFunc = dialFunc
|
|
||||||
} else {
|
|
||||||
defaultDialer := makeDefaultDialer()
|
|
||||||
config.DialFunc = defaultDialer.DialContext
|
|
||||||
}
|
|
||||||
|
|
||||||
notRuntimeParams := map[string]struct{}{
|
|
||||||
"host": struct{}{},
|
|
||||||
"port": struct{}{},
|
|
||||||
"database": struct{}{},
|
|
||||||
"user": struct{}{},
|
|
||||||
"password": struct{}{},
|
|
||||||
"passfile": struct{}{},
|
|
||||||
"connect_timeout": struct{}{},
|
|
||||||
"sslmode": struct{}{},
|
|
||||||
"sslkey": struct{}{},
|
|
||||||
"sslcert": struct{}{},
|
|
||||||
"sslrootcert": struct{}{},
|
|
||||||
"target_session_attrs": struct{}{},
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v := range settings {
|
|
||||||
if _, present := notRuntimeParams[k]; present {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
config.RuntimeParams[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
fallbacks := []*FallbackConfig{}
|
|
||||||
|
|
||||||
hosts := strings.Split(settings["host"], ",")
|
|
||||||
ports := strings.Split(settings["port"], ",")
|
|
||||||
|
|
||||||
for i, host := range hosts {
|
|
||||||
var portStr string
|
|
||||||
if i < len(ports) {
|
|
||||||
portStr = ports[i]
|
|
||||||
} else {
|
|
||||||
portStr = ports[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
port, err := parsePort(portStr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid port: %v", settings["port"])
|
|
||||||
}
|
|
||||||
|
|
||||||
var tlsConfigs []*tls.Config
|
|
||||||
|
|
||||||
// Ignore TLS settings if Unix domain socket like libpq
|
|
||||||
if network, _ := NetworkAddress(host, port); network == "unix" {
|
|
||||||
tlsConfigs = append(tlsConfigs, nil)
|
|
||||||
} else {
|
|
||||||
var err error
|
|
||||||
tlsConfigs, err = configTLS(settings)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tlsConfig := range tlsConfigs {
|
|
||||||
fallbacks = append(fallbacks, &FallbackConfig{
|
|
||||||
Host: host,
|
|
||||||
Port: port,
|
|
||||||
TLSConfig: tlsConfig,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
config.Host = fallbacks[0].Host
|
|
||||||
config.Port = fallbacks[0].Port
|
|
||||||
config.TLSConfig = fallbacks[0].TLSConfig
|
|
||||||
config.Fallbacks = fallbacks[1:]
|
|
||||||
|
|
||||||
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
|
|
||||||
if err == nil {
|
|
||||||
if config.Password == "" {
|
|
||||||
host := config.Host
|
|
||||||
if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" {
|
|
||||||
host = "localhost"
|
|
||||||
}
|
|
||||||
|
|
||||||
config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if settings["target_session_attrs"] == "read-write" {
|
|
||||||
config.AfterConnectFunc = AfterConnectTargetSessionAttrsReadWrite
|
|
||||||
} else if settings["target_session_attrs"] != "any" {
|
|
||||||
return nil, fmt.Errorf("unknown target_session_attrs value %v", settings["target_session_attrs"])
|
|
||||||
}
|
|
||||||
|
|
||||||
return config, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func defaultSettings() map[string]string {
|
|
||||||
settings := make(map[string]string)
|
|
||||||
|
|
||||||
settings["host"] = defaultHost()
|
|
||||||
settings["port"] = "5432"
|
|
||||||
|
|
||||||
// Default to the OS user name. Purposely ignoring err getting user name from
|
|
||||||
// OS. The client application will simply have to specify the user in that
|
|
||||||
// case (which they typically will be doing anyway).
|
|
||||||
user, err := user.Current()
|
|
||||||
if err == nil {
|
|
||||||
settings["user"] = user.Username
|
|
||||||
settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass")
|
|
||||||
}
|
|
||||||
|
|
||||||
settings["target_session_attrs"] = "any"
|
|
||||||
|
|
||||||
return settings
|
|
||||||
}
|
|
||||||
|
|
||||||
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
|
|
||||||
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
|
|
||||||
// checks the existence of common locations.
|
|
||||||
func defaultHost() string {
|
|
||||||
candidatePaths := []string{
|
|
||||||
"/var/run/postgresql", // Debian
|
|
||||||
"/private/tmp", // OSX - homebrew
|
|
||||||
"/tmp", // standard PostgreSQL
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, path := range candidatePaths {
|
|
||||||
if _, err := os.Stat(path); err == nil {
|
|
||||||
return path
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return "localhost"
|
|
||||||
}
|
|
||||||
|
|
||||||
func addEnvSettings(settings map[string]string) {
|
|
||||||
nameMap := map[string]string{
|
|
||||||
"PGHOST": "host",
|
|
||||||
"PGPORT": "port",
|
|
||||||
"PGDATABASE": "database",
|
|
||||||
"PGUSER": "user",
|
|
||||||
"PGPASSWORD": "password",
|
|
||||||
"PGPASSFILE": "passfile",
|
|
||||||
"PGAPPNAME": "application_name",
|
|
||||||
"PGCONNECT_TIMEOUT": "connect_timeout",
|
|
||||||
"PGSSLMODE": "sslmode",
|
|
||||||
"PGSSLKEY": "sslkey",
|
|
||||||
"PGSSLCERT": "sslcert",
|
|
||||||
"PGSSLROOTCERT": "sslrootcert",
|
|
||||||
"PGTARGETSESSIONATTRS": "target_session_attrs",
|
|
||||||
}
|
|
||||||
|
|
||||||
for envname, realname := range nameMap {
|
|
||||||
value := os.Getenv(envname)
|
|
||||||
if value != "" {
|
|
||||||
settings[realname] = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func addURLSettings(settings map[string]string, connString string) error {
|
|
||||||
url, err := url.Parse(connString)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if url.User != nil {
|
|
||||||
settings["user"] = url.User.Username()
|
|
||||||
if password, present := url.User.Password(); present {
|
|
||||||
settings["password"] = password
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
|
|
||||||
var hosts []string
|
|
||||||
var ports []string
|
|
||||||
for _, host := range strings.Split(url.Host, ",") {
|
|
||||||
parts := strings.SplitN(host, ":", 2)
|
|
||||||
if parts[0] != "" {
|
|
||||||
hosts = append(hosts, parts[0])
|
|
||||||
}
|
|
||||||
if len(parts) == 2 {
|
|
||||||
ports = append(ports, parts[1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(hosts) > 0 {
|
|
||||||
settings["host"] = strings.Join(hosts, ",")
|
|
||||||
}
|
|
||||||
if len(ports) > 0 {
|
|
||||||
settings["port"] = strings.Join(ports, ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
database := strings.TrimLeft(url.Path, "/")
|
|
||||||
if database != "" {
|
|
||||||
settings["database"] = database
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v := range url.Query() {
|
|
||||||
settings[k] = v[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`)
|
|
||||||
|
|
||||||
func addDSNSettings(settings map[string]string, s string) error {
|
|
||||||
m := dsnRegexp.FindAllStringSubmatch(s, -1)
|
|
||||||
|
|
||||||
for _, b := range m {
|
|
||||||
settings[b[1]] = b[2]
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type pgTLSArgs struct {
|
|
||||||
sslMode string
|
|
||||||
sslRootCert string
|
|
||||||
sslCert string
|
|
||||||
sslKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is
|
|
||||||
// necessary to allow returning multiple TLS configs as sslmode "allow" and
|
|
||||||
// "prefer" allow fallback.
|
|
||||||
func configTLS(settings map[string]string) ([]*tls.Config, error) {
|
|
||||||
host := settings["host"]
|
|
||||||
sslmode := settings["sslmode"]
|
|
||||||
sslrootcert := settings["sslrootcert"]
|
|
||||||
sslcert := settings["sslcert"]
|
|
||||||
sslkey := settings["sslkey"]
|
|
||||||
|
|
||||||
// Match libpq default behavior
|
|
||||||
if sslmode == "" {
|
|
||||||
sslmode = "prefer"
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsConfig := &tls.Config{}
|
|
||||||
|
|
||||||
switch sslmode {
|
|
||||||
case "disable":
|
|
||||||
return []*tls.Config{nil}, nil
|
|
||||||
case "allow", "prefer":
|
|
||||||
tlsConfig.InsecureSkipVerify = true
|
|
||||||
case "require":
|
|
||||||
tlsConfig.InsecureSkipVerify = sslrootcert == ""
|
|
||||||
case "verify-ca", "verify-full":
|
|
||||||
tlsConfig.ServerName = host
|
|
||||||
default:
|
|
||||||
return nil, errors.New("sslmode is invalid")
|
|
||||||
}
|
|
||||||
|
|
||||||
if sslrootcert != "" {
|
|
||||||
caCertPool := x509.NewCertPool()
|
|
||||||
|
|
||||||
caPath := sslrootcert
|
|
||||||
caCert, err := ioutil.ReadFile(caPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "unable to read CA file %q", caPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
|
||||||
return nil, errors.Wrap(err, "unable to add CA to cert pool")
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsConfig.RootCAs = caCertPool
|
|
||||||
tlsConfig.ClientCAs = caCertPool
|
|
||||||
}
|
|
||||||
|
|
||||||
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
|
|
||||||
return nil, fmt.Errorf(`both "sslcert" and "sslkey" are required`)
|
|
||||||
}
|
|
||||||
|
|
||||||
if sslcert != "" && sslkey != "" {
|
|
||||||
cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "unable to read cert")
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch sslmode {
|
|
||||||
case "allow":
|
|
||||||
return []*tls.Config{nil, tlsConfig}, nil
|
|
||||||
case "prefer":
|
|
||||||
return []*tls.Config{tlsConfig, nil}, nil
|
|
||||||
case "require", "verify-ca", "verify-full":
|
|
||||||
return []*tls.Config{tlsConfig}, nil
|
|
||||||
default:
|
|
||||||
panic("BUG: bad sslmode should already have been caught")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func parsePort(s string) (uint16, error) {
|
|
||||||
port, err := strconv.ParseUint(s, 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if port < 1 || port > math.MaxUint16 {
|
|
||||||
return 0, errors.New("outside range")
|
|
||||||
}
|
|
||||||
return uint16(port), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeDefaultDialer() *net.Dialer {
|
|
||||||
return &net.Dialer{KeepAlive: 5 * time.Minute}
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeConnectTimeoutDialFunc(s string) (DialFunc, error) {
|
|
||||||
timeout, err := strconv.ParseInt(s, 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if timeout < 0 {
|
|
||||||
return nil, errors.New("negative timeout")
|
|
||||||
}
|
|
||||||
|
|
||||||
d := makeDefaultDialer()
|
|
||||||
d.Timeout = time.Duration(timeout) * time.Second
|
|
||||||
return d.DialContext, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible
|
|
||||||
// target_session_attrs=read-write.
|
|
||||||
func AfterConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
|
|
||||||
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
|
|
||||||
if result.Err != nil {
|
|
||||||
return result.Err
|
|
||||||
}
|
|
||||||
|
|
||||||
if string(result.Rows[0][0]) == "on" {
|
|
||||||
return errors.New("read only connection")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,562 +0,0 @@
|
|||||||
package pgconn_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"os"
|
|
||||||
"os/user"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgconn"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestParseConfig(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
var osUserName string
|
|
||||||
osUser, err := user.Current()
|
|
||||||
if err == nil {
|
|
||||||
osUserName = osUser.Username
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
connString string
|
|
||||||
config *pgconn.Config
|
|
||||||
}{
|
|
||||||
// Test all sslmodes
|
|
||||||
{
|
|
||||||
name: "sslmode not set (prefer)",
|
|
||||||
connString: "postgres://jack:secret@localhost:5432/mydb",
|
|
||||||
config: &pgconn.Config{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
Fallbacks: []*pgconn.FallbackConfig{
|
|
||||||
&pgconn.FallbackConfig{
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
TLSConfig: nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "sslmode disable",
|
|
||||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable",
|
|
||||||
config: &pgconn.Config{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "sslmode allow",
|
|
||||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=allow",
|
|
||||||
config: &pgconn.Config{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
Fallbacks: []*pgconn.FallbackConfig{
|
|
||||||
&pgconn.FallbackConfig{
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "sslmode prefer",
|
|
||||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer",
|
|
||||||
config: &pgconn.Config{
|
|
||||||
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
Fallbacks: []*pgconn.FallbackConfig{
|
|
||||||
&pgconn.FallbackConfig{
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
TLSConfig: nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "sslmode require",
|
|
||||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=require",
|
|
||||||
config: &pgconn.Config{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "sslmode verify-ca",
|
|
||||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca",
|
|
||||||
config: &pgconn.Config{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{ServerName: "localhost"},
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "sslmode verify-full",
|
|
||||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-full",
|
|
||||||
config: &pgconn.Config{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{ServerName: "localhost"},
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "database url everything",
|
|
||||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema",
|
|
||||||
config: &pgconn.Config{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{
|
|
||||||
"application_name": "pgxtest",
|
|
||||||
"search_path": "myschema",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "database url missing password",
|
|
||||||
connString: "postgres://jack@localhost:5432/mydb?sslmode=disable",
|
|
||||||
config: &pgconn.Config{
|
|
||||||
User: "jack",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "database url missing user and password",
|
|
||||||
connString: "postgres://localhost:5432/mydb?sslmode=disable",
|
|
||||||
config: &pgconn.Config{
|
|
||||||
User: osUserName,
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "database url missing port",
|
|
||||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable",
|
|
||||||
config: &pgconn.Config{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "database url unix domain socket host",
|
|
||||||
connString: "postgres:///foo?host=/tmp",
|
|
||||||
config: &pgconn.Config{
|
|
||||||
User: osUserName,
|
|
||||||
Host: "/tmp",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "foo",
|
|
||||||
TLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "DSN everything",
|
|
||||||
connString: "user=jack password=secret host=localhost port=5432 database=mydb sslmode=disable application_name=pgxtest search_path=myschema",
|
|
||||||
config: &pgconn.Config{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{
|
|
||||||
"application_name": "pgxtest",
|
|
||||||
"search_path": "myschema",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "URL multiple hosts",
|
|
||||||
connString: "postgres://jack:secret@foo,bar,baz/mydb?sslmode=disable",
|
|
||||||
config: &pgconn.Config{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "foo",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
Fallbacks: []*pgconn.FallbackConfig{
|
|
||||||
&pgconn.FallbackConfig{
|
|
||||||
Host: "bar",
|
|
||||||
Port: 5432,
|
|
||||||
TLSConfig: nil,
|
|
||||||
},
|
|
||||||
&pgconn.FallbackConfig{
|
|
||||||
Host: "baz",
|
|
||||||
Port: 5432,
|
|
||||||
TLSConfig: nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "URL multiple hosts and ports",
|
|
||||||
connString: "postgres://jack:secret@foo:1,bar:2,baz:3/mydb?sslmode=disable",
|
|
||||||
config: &pgconn.Config{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "foo",
|
|
||||||
Port: 1,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
Fallbacks: []*pgconn.FallbackConfig{
|
|
||||||
&pgconn.FallbackConfig{
|
|
||||||
Host: "bar",
|
|
||||||
Port: 2,
|
|
||||||
TLSConfig: nil,
|
|
||||||
},
|
|
||||||
&pgconn.FallbackConfig{
|
|
||||||
Host: "baz",
|
|
||||||
Port: 3,
|
|
||||||
TLSConfig: nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "DSN multiple hosts one port",
|
|
||||||
connString: "user=jack password=secret host=foo,bar,baz port=5432 database=mydb sslmode=disable",
|
|
||||||
config: &pgconn.Config{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "foo",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
Fallbacks: []*pgconn.FallbackConfig{
|
|
||||||
&pgconn.FallbackConfig{
|
|
||||||
Host: "bar",
|
|
||||||
Port: 5432,
|
|
||||||
TLSConfig: nil,
|
|
||||||
},
|
|
||||||
&pgconn.FallbackConfig{
|
|
||||||
Host: "baz",
|
|
||||||
Port: 5432,
|
|
||||||
TLSConfig: nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "DSN multiple hosts multiple ports",
|
|
||||||
connString: "user=jack password=secret host=foo,bar,baz port=1,2,3 database=mydb sslmode=disable",
|
|
||||||
config: &pgconn.Config{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "foo",
|
|
||||||
Port: 1,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
Fallbacks: []*pgconn.FallbackConfig{
|
|
||||||
&pgconn.FallbackConfig{
|
|
||||||
Host: "bar",
|
|
||||||
Port: 2,
|
|
||||||
TLSConfig: nil,
|
|
||||||
},
|
|
||||||
&pgconn.FallbackConfig{
|
|
||||||
Host: "baz",
|
|
||||||
Port: 3,
|
|
||||||
TLSConfig: nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple hosts and fallback tsl",
|
|
||||||
connString: "user=jack password=secret host=foo,bar,baz database=mydb sslmode=prefer",
|
|
||||||
config: &pgconn.Config{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "foo",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
Fallbacks: []*pgconn.FallbackConfig{
|
|
||||||
&pgconn.FallbackConfig{
|
|
||||||
Host: "foo",
|
|
||||||
Port: 5432,
|
|
||||||
TLSConfig: nil,
|
|
||||||
},
|
|
||||||
&pgconn.FallbackConfig{
|
|
||||||
Host: "bar",
|
|
||||||
Port: 5432,
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
}},
|
|
||||||
&pgconn.FallbackConfig{
|
|
||||||
Host: "bar",
|
|
||||||
Port: 5432,
|
|
||||||
TLSConfig: nil,
|
|
||||||
},
|
|
||||||
&pgconn.FallbackConfig{
|
|
||||||
Host: "baz",
|
|
||||||
Port: 5432,
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
}},
|
|
||||||
&pgconn.FallbackConfig{
|
|
||||||
Host: "baz",
|
|
||||||
Port: 5432,
|
|
||||||
TLSConfig: nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "target_session_attrs",
|
|
||||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write",
|
|
||||||
config: &pgconn.Config{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
AfterConnectFunc: pgconn.AfterConnectTargetSessionAttrsReadWrite,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tt := range tests {
|
|
||||||
config, err := pgconn.ParseConfig(tt.connString)
|
|
||||||
if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) {
|
|
||||||
if !assert.NotNil(t, expected) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !assert.NotNil(t, actual) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName)
|
|
||||||
assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName)
|
|
||||||
assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
|
|
||||||
assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
|
|
||||||
assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
|
|
||||||
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
|
|
||||||
|
|
||||||
// Can't test function equality, so just test that they are set or not.
|
|
||||||
assert.Equalf(t, expected.AfterConnectFunc == nil, actual.AfterConnectFunc == nil, "%s - AfterConnectFunc", testName)
|
|
||||||
|
|
||||||
if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
|
|
||||||
if expected.TLSConfig != nil {
|
|
||||||
assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName)
|
|
||||||
assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) {
|
|
||||||
for i := range expected.Fallbacks {
|
|
||||||
assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i)
|
|
||||||
assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i)
|
|
||||||
|
|
||||||
if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) {
|
|
||||||
if expected.Fallbacks[i].TLSConfig != nil {
|
|
||||||
assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName)
|
|
||||||
assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseConfigEnvLibpq(t *testing.T) {
|
|
||||||
var osUserName string
|
|
||||||
osUser, err := user.Current()
|
|
||||||
if err == nil {
|
|
||||||
osUserName = osUser.Username
|
|
||||||
}
|
|
||||||
|
|
||||||
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"}
|
|
||||||
|
|
||||||
savedEnv := make(map[string]string)
|
|
||||||
for _, n := range pgEnvvars {
|
|
||||||
savedEnv[n] = os.Getenv(n)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
for k, v := range savedEnv {
|
|
||||||
err := os.Setenv(k, v)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unable to restore environment: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
envvars map[string]string
|
|
||||||
config *pgconn.Config
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
// not testing no environment at all as that would use default host and that can vary.
|
|
||||||
name: "PGHOST only",
|
|
||||||
envvars: map[string]string{"PGHOST": "123.123.123.123"},
|
|
||||||
config: &pgconn.Config{
|
|
||||||
User: osUserName,
|
|
||||||
Host: "123.123.123.123",
|
|
||||||
Port: 5432,
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
Fallbacks: []*pgconn.FallbackConfig{
|
|
||||||
&pgconn.FallbackConfig{
|
|
||||||
Host: "123.123.123.123",
|
|
||||||
Port: 5432,
|
|
||||||
TLSConfig: nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "All non-TLS environment",
|
|
||||||
envvars: map[string]string{
|
|
||||||
"PGHOST": "123.123.123.123",
|
|
||||||
"PGPORT": "7777",
|
|
||||||
"PGDATABASE": "foo",
|
|
||||||
"PGUSER": "bar",
|
|
||||||
"PGPASSWORD": "baz",
|
|
||||||
"PGCONNECT_TIMEOUT": "10",
|
|
||||||
"PGSSLMODE": "disable",
|
|
||||||
"PGAPPNAME": "pgxtest",
|
|
||||||
},
|
|
||||||
config: &pgconn.Config{
|
|
||||||
Host: "123.123.123.123",
|
|
||||||
Port: 7777,
|
|
||||||
Database: "foo",
|
|
||||||
User: "bar",
|
|
||||||
Password: "baz",
|
|
||||||
TLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{"application_name": "pgxtest"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tt := range tests {
|
|
||||||
for _, n := range pgEnvvars {
|
|
||||||
err := os.Unsetenv(n)
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v := range tt.envvars {
|
|
||||||
err := os.Setenv(k, v)
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
config, err := pgconn.ParseConfig("")
|
|
||||||
if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseConfigReadsPgPassfile(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tf, err := ioutil.TempFile("", "")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
defer tf.Close()
|
|
||||||
defer os.Remove(tf.Name())
|
|
||||||
|
|
||||||
_, err = tf.Write([]byte("test1:5432:curlydb:curly:nyuknyuknyuk"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tf.Name())
|
|
||||||
expected := &pgconn.Config{
|
|
||||||
User: "curly",
|
|
||||||
Password: "nyuknyuknyuk",
|
|
||||||
Host: "test1",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "curlydb",
|
|
||||||
TLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
}
|
|
||||||
|
|
||||||
actual, err := pgconn.ParseConfig(connString)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assertConfigsEqual(t, expected, actual, "passfile")
|
|
||||||
}
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
// Package pgconn is a low-level PostgreSQL database driver.
|
|
||||||
/*
|
|
||||||
pgconn provides lower level access to a PostgreSQL connection than a database/sql or pgx connection. It operates at
|
|
||||||
nearly the same level is the C library libpq.
|
|
||||||
|
|
||||||
Establishing a Connection
|
|
||||||
|
|
||||||
Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for
|
|
||||||
libpq style environment variables.
|
|
||||||
|
|
||||||
Executing a Query
|
|
||||||
|
|
||||||
ExecParams and ExecPrepared execute a single query. They return readers that iterate over each row. The Read method
|
|
||||||
reads all rows into memory.
|
|
||||||
|
|
||||||
Executing Multiple Queries in a Single Round Trip
|
|
||||||
|
|
||||||
Exec and ExecBatch can execute multiple queries in a single round trip. The return readers that iterate over each query
|
|
||||||
result. The ReadAll method reads all query results into memory.
|
|
||||||
|
|
||||||
Context Support
|
|
||||||
|
|
||||||
All potentially blocking operations take a context.Context. If a context is canceled while a query is in progress the
|
|
||||||
method immediately returns. In the background a cancel request will be sent to the PostgreSQL server. If the
|
|
||||||
cancellation fails or hangs for more than a short time (approximately 15 seconds) the connection will be closed. It is
|
|
||||||
safe to use the connection while this background cancellation is in progress. Any calls will block until the
|
|
||||||
cancellation and resynchronization is complete (and those calls can be aborted by a context cancellation).
|
|
||||||
*/
|
|
||||||
package pgconn
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
package pgconn_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgconn"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func closeConn(t testing.TB, conn *pgconn.PgConn) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
||||||
defer cancel()
|
|
||||||
require.Nil(t, conn.Close(ctx))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do a simple query to ensure the connection is still usable
|
|
||||||
func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
||||||
result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).Read()
|
|
||||||
cancel()
|
|
||||||
|
|
||||||
require.Nil(t, result.Err)
|
|
||||||
assert.Equal(t, 3, len(result.Rows))
|
|
||||||
assert.Equal(t, "1", string(result.Rows[0][0]))
|
|
||||||
assert.Equal(t, "2", string(result.Rows[1][0]))
|
|
||||||
assert.Equal(t, "3", string(result.Rows[2][0]))
|
|
||||||
}
|
|
||||||
-1407
File diff suppressed because it is too large
Load Diff
@@ -1,120 +0,0 @@
|
|||||||
package pgconn_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"math/rand"
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgconn"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestConnStress(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer closeConn(t, pgConn)
|
|
||||||
|
|
||||||
actionCount := 100
|
|
||||||
if s := os.Getenv("PGX_TEST_STRESS_FACTOR"); s != "" {
|
|
||||||
stressFactor, err := strconv.ParseInt(s, 10, 64)
|
|
||||||
require.Nil(t, err, "Failed to parse PGX_TEST_STRESS_FACTOR")
|
|
||||||
actionCount *= int(stressFactor)
|
|
||||||
}
|
|
||||||
|
|
||||||
setupStressDB(t, pgConn)
|
|
||||||
|
|
||||||
actions := []struct {
|
|
||||||
name string
|
|
||||||
fn func(*pgconn.PgConn) error
|
|
||||||
}{
|
|
||||||
{"Exec Select", stressExecSelect},
|
|
||||||
{"ExecParams Select", stressExecParamsSelect},
|
|
||||||
{"Batch", stressBatch},
|
|
||||||
{"ExecCanceled", stressExecSelectCanceled},
|
|
||||||
{"ExecParamsCanceled", stressExecParamsSelectCanceled},
|
|
||||||
{"BatchCanceled", stressBatchCanceled},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < actionCount; i++ {
|
|
||||||
action := actions[rand.Intn(len(actions))]
|
|
||||||
err := action.fn(pgConn)
|
|
||||||
require.Nilf(t, err, "%d: %s", i, action.name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) {
|
|
||||||
_, err := pgConn.Exec(context.Background(), `
|
|
||||||
create temporary table widgets(
|
|
||||||
id serial primary key,
|
|
||||||
name varchar not null,
|
|
||||||
description text,
|
|
||||||
creation_time timestamptz default now()
|
|
||||||
);
|
|
||||||
|
|
||||||
insert into widgets(name, description) values
|
|
||||||
('Foo', 'bar'),
|
|
||||||
('baz', 'Something really long Something really long Something really long Something really long Something really long'),
|
|
||||||
('a', 'b')`).ReadAll()
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func stressExecSelect(pgConn *pgconn.PgConn) error {
|
|
||||||
_, err := pgConn.Exec(context.Background(), "select * from widgets").ReadAll()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func stressExecParamsSelect(pgConn *pgconn.PgConn) error {
|
|
||||||
result := pgConn.ExecParams(context.Background(), "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read()
|
|
||||||
return result.Err
|
|
||||||
}
|
|
||||||
|
|
||||||
func stressBatch(pgConn *pgconn.PgConn) error {
|
|
||||||
batch := &pgconn.Batch{}
|
|
||||||
|
|
||||||
batch.ExecParams("select * from widgets", nil, nil, nil, nil)
|
|
||||||
batch.ExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil)
|
|
||||||
_, err := pgConn.ExecBatch(context.Background(), batch).ReadAll()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func stressExecSelectCanceled(pgConn *pgconn.PgConn) error {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
|
||||||
_, err := pgConn.Exec(ctx, "select *, pg_sleep(1) from widgets").ReadAll()
|
|
||||||
cancel()
|
|
||||||
if err != context.DeadlineExceeded {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func stressExecParamsSelectCanceled(pgConn *pgconn.PgConn) error {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
|
||||||
result := pgConn.ExecParams(ctx, "select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read()
|
|
||||||
cancel()
|
|
||||||
if result.Err != context.DeadlineExceeded {
|
|
||||||
return result.Err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func stressBatchCanceled(pgConn *pgconn.PgConn) error {
|
|
||||||
batch := &pgconn.Batch{}
|
|
||||||
batch.ExecParams("select * from widgets", nil, nil, nil, nil)
|
|
||||||
batch.ExecParams("select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil)
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
|
||||||
_, err := pgConn.ExecBatch(ctx, batch).ReadAll()
|
|
||||||
cancel()
|
|
||||||
if err != context.DeadlineExceeded {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +0,0 @@
|
|||||||
// Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol.
|
|
||||||
/*
|
|
||||||
pgio provides functions for appending integers to a []byte while doing byte
|
|
||||||
order conversion.
|
|
||||||
*/
|
|
||||||
package pgio
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
package pgio
|
|
||||||
|
|
||||||
import "encoding/binary"
|
|
||||||
|
|
||||||
func AppendUint16(buf []byte, n uint16) []byte {
|
|
||||||
wp := len(buf)
|
|
||||||
buf = append(buf, 0, 0)
|
|
||||||
binary.BigEndian.PutUint16(buf[wp:], n)
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
|
|
||||||
func AppendUint32(buf []byte, n uint32) []byte {
|
|
||||||
wp := len(buf)
|
|
||||||
buf = append(buf, 0, 0, 0, 0)
|
|
||||||
binary.BigEndian.PutUint32(buf[wp:], n)
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
|
|
||||||
func AppendUint64(buf []byte, n uint64) []byte {
|
|
||||||
wp := len(buf)
|
|
||||||
buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0)
|
|
||||||
binary.BigEndian.PutUint64(buf[wp:], n)
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
|
|
||||||
func AppendInt16(buf []byte, n int16) []byte {
|
|
||||||
return AppendUint16(buf, uint16(n))
|
|
||||||
}
|
|
||||||
|
|
||||||
func AppendInt32(buf []byte, n int32) []byte {
|
|
||||||
return AppendUint32(buf, uint32(n))
|
|
||||||
}
|
|
||||||
|
|
||||||
func AppendInt64(buf []byte, n int64) []byte {
|
|
||||||
return AppendUint64(buf, uint64(n))
|
|
||||||
}
|
|
||||||
|
|
||||||
func SetInt32(buf []byte, n int32) {
|
|
||||||
binary.BigEndian.PutUint32(buf, uint32(n))
|
|
||||||
}
|
|
||||||
@@ -1,78 +0,0 @@
|
|||||||
package pgio
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAppendUint16NilBuf(t *testing.T) {
|
|
||||||
buf := AppendUint16(nil, 1)
|
|
||||||
if !reflect.DeepEqual(buf, []byte{0, 1}) {
|
|
||||||
t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAppendUint16EmptyBuf(t *testing.T) {
|
|
||||||
buf := []byte{}
|
|
||||||
buf = AppendUint16(buf, 1)
|
|
||||||
if !reflect.DeepEqual(buf, []byte{0, 1}) {
|
|
||||||
t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAppendUint16BufWithCapacityDoesNotAllocate(t *testing.T) {
|
|
||||||
buf := make([]byte, 0, 4)
|
|
||||||
AppendUint16(buf, 1)
|
|
||||||
buf = buf[0:2]
|
|
||||||
if !reflect.DeepEqual(buf, []byte{0, 1}) {
|
|
||||||
t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAppendUint32NilBuf(t *testing.T) {
|
|
||||||
buf := AppendUint32(nil, 1)
|
|
||||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) {
|
|
||||||
t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAppendUint32EmptyBuf(t *testing.T) {
|
|
||||||
buf := []byte{}
|
|
||||||
buf = AppendUint32(buf, 1)
|
|
||||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) {
|
|
||||||
t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAppendUint32BufWithCapacityDoesNotAllocate(t *testing.T) {
|
|
||||||
buf := make([]byte, 0, 4)
|
|
||||||
AppendUint32(buf, 1)
|
|
||||||
buf = buf[0:4]
|
|
||||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) {
|
|
||||||
t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAppendUint64NilBuf(t *testing.T) {
|
|
||||||
buf := AppendUint64(nil, 1)
|
|
||||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) {
|
|
||||||
t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAppendUint64EmptyBuf(t *testing.T) {
|
|
||||||
buf := []byte{}
|
|
||||||
buf = AppendUint64(buf, 1)
|
|
||||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) {
|
|
||||||
t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAppendUint64BufWithCapacityDoesNotAllocate(t *testing.T) {
|
|
||||||
buf := make([]byte, 0, 8)
|
|
||||||
AppendUint64(buf, 1)
|
|
||||||
buf = buf[0:8]
|
|
||||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) {
|
|
||||||
t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
+2
-2
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgproto3"
|
"github.com/jackc/pgproto3"
|
||||||
"github.com/jackc/pgx/pgtype"
|
"github.com/jackc/pgx/pgtype"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -43,7 +43,7 @@ func (s *Server) ServeOne() error {
|
|||||||
|
|
||||||
s.Close()
|
s.Close()
|
||||||
|
|
||||||
backend, err := pgproto3.NewBackend(conn, conn)
|
backend, err := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -1,109 +0,0 @@
|
|||||||
package pgpassfile
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Entry represents a line in a PG passfile.
|
|
||||||
type Entry struct {
|
|
||||||
Hostname string
|
|
||||||
Port string
|
|
||||||
Database string
|
|
||||||
Username string
|
|
||||||
Password string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Passfile is the in memory data structure representing a PG passfile.
|
|
||||||
type Passfile struct {
|
|
||||||
Entries []*Entry
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadPassfile reads the file at path and parses it into a Passfile.
|
|
||||||
func ReadPassfile(path string) (*Passfile, error) {
|
|
||||||
f, err := os.Open(path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
return ParsePassfile(f)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParsePassfile reads r and parses it into a Passfile.
|
|
||||||
func ParsePassfile(r io.Reader) (*Passfile, error) {
|
|
||||||
passfile := &Passfile{}
|
|
||||||
|
|
||||||
scanner := bufio.NewScanner(r)
|
|
||||||
for scanner.Scan() {
|
|
||||||
entry := parseLine(scanner.Text())
|
|
||||||
if entry != nil {
|
|
||||||
passfile.Entries = append(passfile.Entries, entry)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return passfile, scanner.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Match (not colons or escaped colon or escaped backslash)+. Essentially gives a split on unescaped
|
|
||||||
// colon.
|
|
||||||
var colonSplitterRegexp = regexp.MustCompile("(([^:]|(\\:)))+")
|
|
||||||
|
|
||||||
// var colonSplitterRegexp = regexp.MustCompile("((?:[^:]|(?:\\:)|(?:\\\\))+)")
|
|
||||||
|
|
||||||
// parseLine parses a line into an *Entry. It returns nil on comment lines or any other unparsable
|
|
||||||
// line.
|
|
||||||
func parseLine(line string) *Entry {
|
|
||||||
const (
|
|
||||||
tmpBackslash = "\r"
|
|
||||||
tmpColon = "\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
line = strings.TrimSpace(line)
|
|
||||||
|
|
||||||
if strings.HasPrefix(line, "#") {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
line = strings.Replace(line, `\\`, tmpBackslash, -1)
|
|
||||||
line = strings.Replace(line, `\:`, tmpColon, -1)
|
|
||||||
|
|
||||||
parts := strings.Split(line, ":")
|
|
||||||
if len(parts) != 5 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unescape escaped colons and backslashes
|
|
||||||
for i := range parts {
|
|
||||||
parts[i] = strings.Replace(parts[i], tmpBackslash, `\`, -1)
|
|
||||||
parts[i] = strings.Replace(parts[i], tmpColon, `:`, -1)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Entry{
|
|
||||||
Hostname: parts[0],
|
|
||||||
Port: parts[1],
|
|
||||||
Database: parts[2],
|
|
||||||
Username: parts[3],
|
|
||||||
Password: parts[4],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindPassword finds the password for the provided hostname, port, database, and username. For a
|
|
||||||
// Unix domain socket hostname must be set to "localhost". An empty string will be returned if no
|
|
||||||
// match is found.
|
|
||||||
//
|
|
||||||
// See https://www.postgresql.org/docs/current/libpq-pgpass.html for more password file information.
|
|
||||||
func (pf *Passfile) FindPassword(hostname, port, database, username string) (password string) {
|
|
||||||
for _, e := range pf.Entries {
|
|
||||||
if (e.Hostname == "*" || e.Hostname == hostname) &&
|
|
||||||
(e.Port == "*" || e.Port == port) &&
|
|
||||||
(e.Database == "*" || e.Database == database) &&
|
|
||||||
(e.Username == "*" || e.Username == username) {
|
|
||||||
return e.Password
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
@@ -1,52 +0,0 @@
|
|||||||
package pgpassfile
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func unescape(s string) string {
|
|
||||||
s = strings.Replace(s, `\:`, `:`, -1)
|
|
||||||
s = strings.Replace(s, `\\`, `\`, -1)
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
var passfile = [][]string{
|
|
||||||
{"test1", "5432", "larrydb", "larry", "whatstheidea"},
|
|
||||||
{"test1", "5432", "moedb", "moe", "imbecile"},
|
|
||||||
{"test1", "5432", "curlydb", "curly", "nyuknyuknyuk"},
|
|
||||||
{"test2", "5432", "*", "shemp", "heymoe"},
|
|
||||||
{"test2", "5432", "*", "*", `test\\ing\:`},
|
|
||||||
{"localhost", "*", "*", "*", "sesam"},
|
|
||||||
{"test3", "*", "", "", "swordfish"}, // user will be filled later
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParsePassFile(t *testing.T) {
|
|
||||||
buf := bytes.NewBufferString(`# A comment
|
|
||||||
test1:5432:larrydb:larry:whatstheidea
|
|
||||||
test1:5432:moedb:moe:imbecile
|
|
||||||
test1:5432:curlydb:curly:nyuknyuknyuk
|
|
||||||
test2:5432:*:shemp:heymoe
|
|
||||||
test2:5432:*:*:test\\ing\:
|
|
||||||
localhost:*:*:*:sesam
|
|
||||||
`)
|
|
||||||
|
|
||||||
passfile, err := ParsePassfile(buf)
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
assert.Len(t, passfile.Entries, 6)
|
|
||||||
|
|
||||||
assert.Equal(t, "whatstheidea", passfile.FindPassword("test1", "5432", "larrydb", "larry"))
|
|
||||||
assert.Equal(t, "imbecile", passfile.FindPassword("test1", "5432", "moedb", "moe"))
|
|
||||||
assert.Equal(t, `test\ing:`, passfile.FindPassword("test2", "5432", "something", "else"))
|
|
||||||
assert.Equal(t, "sesam", passfile.FindPassword("localhost", "9999", "foo", "bare"))
|
|
||||||
|
|
||||||
assert.Equal(t, "", passfile.FindPassword("wrong", "5432", "larrydb", "larry"))
|
|
||||||
assert.Equal(t, "", passfile.FindPassword("test1", "wrong", "larrydb", "larry"))
|
|
||||||
assert.Equal(t, "", passfile.FindPassword("test1", "5432", "wrong", "larry"))
|
|
||||||
assert.Equal(t, "", passfile.FindPassword("test1", "5432", "larrydb", "wrong"))
|
|
||||||
}
|
|
||||||
@@ -1,54 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
AuthTypeOk = 0
|
|
||||||
AuthTypeCleartextPassword = 3
|
|
||||||
AuthTypeMD5Password = 5
|
|
||||||
)
|
|
||||||
|
|
||||||
type Authentication struct {
|
|
||||||
Type uint32
|
|
||||||
|
|
||||||
// MD5Password fields
|
|
||||||
Salt [4]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*Authentication) Backend() {}
|
|
||||||
|
|
||||||
func (dst *Authentication) Decode(src []byte) error {
|
|
||||||
*dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])}
|
|
||||||
|
|
||||||
switch dst.Type {
|
|
||||||
case AuthTypeOk:
|
|
||||||
case AuthTypeCleartextPassword:
|
|
||||||
case AuthTypeMD5Password:
|
|
||||||
copy(dst.Salt[:], src[4:8])
|
|
||||||
default:
|
|
||||||
return errors.Errorf("unknown authentication type: %d", dst.Type)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *Authentication) Encode(dst []byte) []byte {
|
|
||||||
dst = append(dst, 'R')
|
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
dst = pgio.AppendUint32(dst, src.Type)
|
|
||||||
|
|
||||||
switch src.Type {
|
|
||||||
case AuthTypeMD5Password:
|
|
||||||
dst = append(dst, src.Salt[:]...)
|
|
||||||
}
|
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
@@ -1,113 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/chunkreader"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Backend struct {
|
|
||||||
cr *chunkreader.ChunkReader
|
|
||||||
w io.Writer
|
|
||||||
|
|
||||||
// Frontend message flyweights
|
|
||||||
bind Bind
|
|
||||||
_close Close
|
|
||||||
copyFail CopyFail
|
|
||||||
describe Describe
|
|
||||||
execute Execute
|
|
||||||
flush Flush
|
|
||||||
parse Parse
|
|
||||||
passwordMessage PasswordMessage
|
|
||||||
query Query
|
|
||||||
startupMessage StartupMessage
|
|
||||||
sync Sync
|
|
||||||
terminate Terminate
|
|
||||||
|
|
||||||
bodyLen int
|
|
||||||
msgType byte
|
|
||||||
partialMsg bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewBackend(r io.Reader, w io.Writer) (*Backend, error) {
|
|
||||||
cr := chunkreader.NewChunkReader(r)
|
|
||||||
return &Backend{cr: cr, w: w}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Backend) Send(msg BackendMessage) error {
|
|
||||||
_, err := b.w.Write(msg.Encode(nil))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) {
|
|
||||||
buf, err := b.cr.Next(4)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
msgSize := int(binary.BigEndian.Uint32(buf) - 4)
|
|
||||||
|
|
||||||
buf, err = b.cr.Next(msgSize)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = b.startupMessage.Decode(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &b.startupMessage, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Backend) Receive() (FrontendMessage, error) {
|
|
||||||
if !b.partialMsg {
|
|
||||||
header, err := b.cr.Next(5)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
b.msgType = header[0]
|
|
||||||
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
|
|
||||||
b.partialMsg = true
|
|
||||||
}
|
|
||||||
|
|
||||||
var msg FrontendMessage
|
|
||||||
switch b.msgType {
|
|
||||||
case 'B':
|
|
||||||
msg = &b.bind
|
|
||||||
case 'C':
|
|
||||||
msg = &b._close
|
|
||||||
case 'D':
|
|
||||||
msg = &b.describe
|
|
||||||
case 'E':
|
|
||||||
msg = &b.execute
|
|
||||||
case 'f':
|
|
||||||
msg = &b.copyFail
|
|
||||||
case 'H':
|
|
||||||
msg = &b.flush
|
|
||||||
case 'P':
|
|
||||||
msg = &b.parse
|
|
||||||
case 'p':
|
|
||||||
msg = &b.passwordMessage
|
|
||||||
case 'Q':
|
|
||||||
msg = &b.query
|
|
||||||
case 'S':
|
|
||||||
msg = &b.sync
|
|
||||||
case 'X':
|
|
||||||
msg = &b.terminate
|
|
||||||
default:
|
|
||||||
return nil, errors.Errorf("unknown message type: %c", b.msgType)
|
|
||||||
}
|
|
||||||
|
|
||||||
msgBody, err := b.cr.Next(b.bodyLen)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
b.partialMsg = false
|
|
||||||
|
|
||||||
err = msg.Decode(msgBody)
|
|
||||||
return msg, err
|
|
||||||
}
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type BackendKeyData struct {
|
|
||||||
ProcessID uint32
|
|
||||||
SecretKey uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*BackendKeyData) Backend() {}
|
|
||||||
|
|
||||||
func (dst *BackendKeyData) Decode(src []byte) error {
|
|
||||||
if len(src) != 8 {
|
|
||||||
return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)}
|
|
||||||
}
|
|
||||||
|
|
||||||
dst.ProcessID = binary.BigEndian.Uint32(src[:4])
|
|
||||||
dst.SecretKey = binary.BigEndian.Uint32(src[4:])
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *BackendKeyData) Encode(dst []byte) []byte {
|
|
||||||
dst = append(dst, 'K')
|
|
||||||
dst = pgio.AppendUint32(dst, 12)
|
|
||||||
dst = pgio.AppendUint32(dst, src.ProcessID)
|
|
||||||
dst = pgio.AppendUint32(dst, src.SecretKey)
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *BackendKeyData) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
ProcessID uint32
|
|
||||||
SecretKey uint32
|
|
||||||
}{
|
|
||||||
Type: "BackendKeyData",
|
|
||||||
ProcessID: src.ProcessID,
|
|
||||||
SecretKey: src.SecretKey,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
package pgproto3_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgproto3"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestBackendReceiveInterrupted(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
server := &interruptReader{}
|
|
||||||
server.push([]byte{'Q', 0, 0, 0, 6})
|
|
||||||
|
|
||||||
backend, err := pgproto3.NewBackend(server, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg, err := backend.Receive()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected err")
|
|
||||||
}
|
|
||||||
if msg != nil {
|
|
||||||
t.Fatalf("did not expect msg, but %v", msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
server.push([]byte{'I', 0})
|
|
||||||
|
|
||||||
msg, err = backend.Receive()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if msg, ok := msg.(*pgproto3.Query); !ok || msg.String != "I" {
|
|
||||||
t.Fatalf("unexpected msg: %v", msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
)
|
|
||||||
|
|
||||||
type BigEndianBuf [8]byte
|
|
||||||
|
|
||||||
func (b BigEndianBuf) Int16(n int16) []byte {
|
|
||||||
buf := b[0:2]
|
|
||||||
binary.BigEndian.PutUint16(buf, uint16(n))
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b BigEndianBuf) Uint16(n uint16) []byte {
|
|
||||||
buf := b[0:2]
|
|
||||||
binary.BigEndian.PutUint16(buf, n)
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b BigEndianBuf) Int32(n int32) []byte {
|
|
||||||
buf := b[0:4]
|
|
||||||
binary.BigEndian.PutUint32(buf, uint32(n))
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b BigEndianBuf) Uint32(n uint32) []byte {
|
|
||||||
buf := b[0:4]
|
|
||||||
binary.BigEndian.PutUint32(buf, n)
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b BigEndianBuf) Int64(n int64) []byte {
|
|
||||||
buf := b[0:8]
|
|
||||||
binary.BigEndian.PutUint64(buf, uint64(n))
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
@@ -1,171 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Bind struct {
|
|
||||||
DestinationPortal string
|
|
||||||
PreparedStatement string
|
|
||||||
ParameterFormatCodes []int16
|
|
||||||
Parameters [][]byte
|
|
||||||
ResultFormatCodes []int16
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*Bind) Frontend() {}
|
|
||||||
|
|
||||||
func (dst *Bind) Decode(src []byte) error {
|
|
||||||
*dst = Bind{}
|
|
||||||
|
|
||||||
idx := bytes.IndexByte(src, 0)
|
|
||||||
if idx < 0 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
|
||||||
}
|
|
||||||
dst.DestinationPortal = string(src[:idx])
|
|
||||||
rp := idx + 1
|
|
||||||
|
|
||||||
idx = bytes.IndexByte(src[rp:], 0)
|
|
||||||
if idx < 0 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
|
||||||
}
|
|
||||||
dst.PreparedStatement = string(src[rp : rp+idx])
|
|
||||||
rp += idx + 1
|
|
||||||
|
|
||||||
if len(src[rp:]) < 2 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
|
||||||
}
|
|
||||||
parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
|
|
||||||
rp += 2
|
|
||||||
|
|
||||||
if parameterFormatCodeCount > 0 {
|
|
||||||
dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount)
|
|
||||||
|
|
||||||
if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
|
||||||
}
|
|
||||||
for i := 0; i < parameterFormatCodeCount; i++ {
|
|
||||||
dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
|
|
||||||
rp += 2
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(src[rp:]) < 2 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
|
||||||
}
|
|
||||||
parameterCount := int(binary.BigEndian.Uint16(src[rp:]))
|
|
||||||
rp += 2
|
|
||||||
|
|
||||||
if parameterCount > 0 {
|
|
||||||
dst.Parameters = make([][]byte, parameterCount)
|
|
||||||
|
|
||||||
for i := 0; i < parameterCount; i++ {
|
|
||||||
if len(src[rp:]) < 4 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
|
||||||
}
|
|
||||||
|
|
||||||
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
|
||||||
rp += 4
|
|
||||||
|
|
||||||
// null
|
|
||||||
if msgSize == -1 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(src[rp:]) < msgSize {
|
|
||||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
|
||||||
}
|
|
||||||
|
|
||||||
dst.Parameters[i] = src[rp : rp+msgSize]
|
|
||||||
rp += msgSize
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(src[rp:]) < 2 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
|
||||||
}
|
|
||||||
resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
|
|
||||||
rp += 2
|
|
||||||
|
|
||||||
dst.ResultFormatCodes = make([]int16, resultFormatCodeCount)
|
|
||||||
if len(src[rp:]) < len(dst.ResultFormatCodes)*2 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
|
||||||
}
|
|
||||||
for i := 0; i < resultFormatCodeCount; i++ {
|
|
||||||
dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
|
|
||||||
rp += 2
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *Bind) Encode(dst []byte) []byte {
|
|
||||||
dst = append(dst, 'B')
|
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = append(dst, src.DestinationPortal...)
|
|
||||||
dst = append(dst, 0)
|
|
||||||
dst = append(dst, src.PreparedStatement...)
|
|
||||||
dst = append(dst, 0)
|
|
||||||
|
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
|
|
||||||
for _, fc := range src.ParameterFormatCodes {
|
|
||||||
dst = pgio.AppendInt16(dst, fc)
|
|
||||||
}
|
|
||||||
|
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
|
|
||||||
for _, p := range src.Parameters {
|
|
||||||
if p == nil {
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
dst = pgio.AppendInt32(dst, int32(len(p)))
|
|
||||||
dst = append(dst, p...)
|
|
||||||
}
|
|
||||||
|
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
|
|
||||||
for _, fc := range src.ResultFormatCodes {
|
|
||||||
dst = pgio.AppendInt16(dst, fc)
|
|
||||||
}
|
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *Bind) MarshalJSON() ([]byte, error) {
|
|
||||||
formattedParameters := make([]map[string]string, len(src.Parameters))
|
|
||||||
for i, p := range src.Parameters {
|
|
||||||
if p == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if src.ParameterFormatCodes[i] == 0 {
|
|
||||||
formattedParameters[i] = map[string]string{"text": string(p)}
|
|
||||||
} else {
|
|
||||||
formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
DestinationPortal string
|
|
||||||
PreparedStatement string
|
|
||||||
ParameterFormatCodes []int16
|
|
||||||
Parameters []map[string]string
|
|
||||||
ResultFormatCodes []int16
|
|
||||||
}{
|
|
||||||
Type: "Bind",
|
|
||||||
DestinationPortal: src.DestinationPortal,
|
|
||||||
PreparedStatement: src.PreparedStatement,
|
|
||||||
ParameterFormatCodes: src.ParameterFormatCodes,
|
|
||||||
Parameters: formattedParameters,
|
|
||||||
ResultFormatCodes: src.ResultFormatCodes,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
type BindComplete struct{}
|
|
||||||
|
|
||||||
func (*BindComplete) Backend() {}
|
|
||||||
|
|
||||||
func (dst *BindComplete) Decode(src []byte) error {
|
|
||||||
if len(src) != 0 {
|
|
||||||
return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *BindComplete) Encode(dst []byte) []byte {
|
|
||||||
return append(dst, '2', 0, 0, 0, 4)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *BindComplete) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
}{
|
|
||||||
Type: "BindComplete",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,59 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Close struct {
|
|
||||||
ObjectType byte // 'S' = prepared statement, 'P' = portal
|
|
||||||
Name string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*Close) Frontend() {}
|
|
||||||
|
|
||||||
func (dst *Close) Decode(src []byte) error {
|
|
||||||
if len(src) < 2 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "Close"}
|
|
||||||
}
|
|
||||||
|
|
||||||
dst.ObjectType = src[0]
|
|
||||||
rp := 1
|
|
||||||
|
|
||||||
idx := bytes.IndexByte(src[rp:], 0)
|
|
||||||
if idx != len(src[rp:])-1 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "Close"}
|
|
||||||
}
|
|
||||||
|
|
||||||
dst.Name = string(src[rp : len(src)-1])
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *Close) Encode(dst []byte) []byte {
|
|
||||||
dst = append(dst, 'C')
|
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = append(dst, src.ObjectType)
|
|
||||||
dst = append(dst, src.Name...)
|
|
||||||
dst = append(dst, 0)
|
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *Close) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
ObjectType string
|
|
||||||
Name string
|
|
||||||
}{
|
|
||||||
Type: "Close",
|
|
||||||
ObjectType: string(src.ObjectType),
|
|
||||||
Name: src.Name,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
type CloseComplete struct{}
|
|
||||||
|
|
||||||
func (*CloseComplete) Backend() {}
|
|
||||||
|
|
||||||
func (dst *CloseComplete) Decode(src []byte) error {
|
|
||||||
if len(src) != 0 {
|
|
||||||
return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *CloseComplete) Encode(dst []byte) []byte {
|
|
||||||
return append(dst, '3', 0, 0, 0, 4)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *CloseComplete) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
}{
|
|
||||||
Type: "CloseComplete",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type CommandComplete struct {
|
|
||||||
CommandTag string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*CommandComplete) Backend() {}
|
|
||||||
|
|
||||||
func (dst *CommandComplete) Decode(src []byte) error {
|
|
||||||
idx := bytes.IndexByte(src, 0)
|
|
||||||
if idx != len(src)-1 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "CommandComplete"}
|
|
||||||
}
|
|
||||||
|
|
||||||
dst.CommandTag = string(src[:idx])
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *CommandComplete) Encode(dst []byte) []byte {
|
|
||||||
dst = append(dst, 'C')
|
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = append(dst, src.CommandTag...)
|
|
||||||
dst = append(dst, 0)
|
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *CommandComplete) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
CommandTag string
|
|
||||||
}{
|
|
||||||
Type: "CommandComplete",
|
|
||||||
CommandTag: src.CommandTag,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type CopyBothResponse struct {
|
|
||||||
OverallFormat byte
|
|
||||||
ColumnFormatCodes []uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*CopyBothResponse) Backend() {}
|
|
||||||
|
|
||||||
func (dst *CopyBothResponse) Decode(src []byte) error {
|
|
||||||
buf := bytes.NewBuffer(src)
|
|
||||||
|
|
||||||
if buf.Len() < 3 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
|
|
||||||
}
|
|
||||||
|
|
||||||
overallFormat := buf.Next(1)[0]
|
|
||||||
|
|
||||||
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
|
||||||
if buf.Len() != columnCount*2 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
|
|
||||||
}
|
|
||||||
|
|
||||||
columnFormatCodes := make([]uint16, columnCount)
|
|
||||||
for i := 0; i < columnCount; i++ {
|
|
||||||
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
|
||||||
}
|
|
||||||
|
|
||||||
*dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *CopyBothResponse) Encode(dst []byte) []byte {
|
|
||||||
dst = append(dst, 'W')
|
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
|
||||||
for _, fc := range src.ColumnFormatCodes {
|
|
||||||
dst = pgio.AppendUint16(dst, fc)
|
|
||||||
}
|
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *CopyBothResponse) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
ColumnFormatCodes []uint16
|
|
||||||
}{
|
|
||||||
Type: "CopyBothResponse",
|
|
||||||
ColumnFormatCodes: src.ColumnFormatCodes,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type CopyData struct {
|
|
||||||
Data []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*CopyData) Backend() {}
|
|
||||||
func (*CopyData) Frontend() {}
|
|
||||||
|
|
||||||
func (dst *CopyData) Decode(src []byte) error {
|
|
||||||
dst.Data = src
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *CopyData) Encode(dst []byte) []byte {
|
|
||||||
dst = append(dst, 'd')
|
|
||||||
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
|
|
||||||
dst = append(dst, src.Data...)
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *CopyData) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
Data string
|
|
||||||
}{
|
|
||||||
Type: "CopyData",
|
|
||||||
Data: hex.EncodeToString(src.Data),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
type CopyDone struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*CopyDone) Backend() {}
|
|
||||||
|
|
||||||
func (dst *CopyDone) Decode(src []byte) error {
|
|
||||||
if len(src) != 0 {
|
|
||||||
return &invalidMessageLenErr{messageType: "CopyDone", expectedLen: 0, actualLen: len(src)}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *CopyDone) Encode(dst []byte) []byte {
|
|
||||||
return append(dst, 'c', 0, 0, 0, 4)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *CopyDone) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
}{
|
|
||||||
Type: "CopyDone",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type CopyFail struct {
|
|
||||||
Error string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*CopyFail) Frontend() {}
|
|
||||||
func (*CopyFail) Backend() {}
|
|
||||||
|
|
||||||
func (dst *CopyFail) Decode(src []byte) error {
|
|
||||||
idx := bytes.IndexByte(src, 0)
|
|
||||||
if idx != len(src)-1 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "CopyFail"}
|
|
||||||
}
|
|
||||||
|
|
||||||
dst.Error = string(src[:idx])
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *CopyFail) Encode(dst []byte) []byte {
|
|
||||||
dst = append(dst, 'C')
|
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = append(dst, src.Error...)
|
|
||||||
dst = append(dst, 0)
|
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *CopyFail) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
Error string
|
|
||||||
}{
|
|
||||||
Type: "CopyFail",
|
|
||||||
Error: src.Error,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type CopyInResponse struct {
|
|
||||||
OverallFormat byte
|
|
||||||
ColumnFormatCodes []uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*CopyInResponse) Backend() {}
|
|
||||||
|
|
||||||
func (dst *CopyInResponse) Decode(src []byte) error {
|
|
||||||
buf := bytes.NewBuffer(src)
|
|
||||||
|
|
||||||
if buf.Len() < 3 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
|
|
||||||
}
|
|
||||||
|
|
||||||
overallFormat := buf.Next(1)[0]
|
|
||||||
|
|
||||||
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
|
||||||
if buf.Len() != columnCount*2 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
|
|
||||||
}
|
|
||||||
|
|
||||||
columnFormatCodes := make([]uint16, columnCount)
|
|
||||||
for i := 0; i < columnCount; i++ {
|
|
||||||
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
|
||||||
}
|
|
||||||
|
|
||||||
*dst = CopyInResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *CopyInResponse) Encode(dst []byte) []byte {
|
|
||||||
dst = append(dst, 'G')
|
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
|
||||||
for _, fc := range src.ColumnFormatCodes {
|
|
||||||
dst = pgio.AppendUint16(dst, fc)
|
|
||||||
}
|
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *CopyInResponse) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
ColumnFormatCodes []uint16
|
|
||||||
}{
|
|
||||||
Type: "CopyInResponse",
|
|
||||||
ColumnFormatCodes: src.ColumnFormatCodes,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type CopyOutResponse struct {
|
|
||||||
OverallFormat byte
|
|
||||||
ColumnFormatCodes []uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*CopyOutResponse) Backend() {}
|
|
||||||
|
|
||||||
func (dst *CopyOutResponse) Decode(src []byte) error {
|
|
||||||
buf := bytes.NewBuffer(src)
|
|
||||||
|
|
||||||
if buf.Len() < 3 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
|
|
||||||
}
|
|
||||||
|
|
||||||
overallFormat := buf.Next(1)[0]
|
|
||||||
|
|
||||||
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
|
||||||
if buf.Len() != columnCount*2 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
|
|
||||||
}
|
|
||||||
|
|
||||||
columnFormatCodes := make([]uint16, columnCount)
|
|
||||||
for i := 0; i < columnCount; i++ {
|
|
||||||
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
|
||||||
}
|
|
||||||
|
|
||||||
*dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *CopyOutResponse) Encode(dst []byte) []byte {
|
|
||||||
dst = append(dst, 'H')
|
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
|
||||||
for _, fc := range src.ColumnFormatCodes {
|
|
||||||
dst = pgio.AppendUint16(dst, fc)
|
|
||||||
}
|
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *CopyOutResponse) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
ColumnFormatCodes []uint16
|
|
||||||
}{
|
|
||||||
Type: "CopyOutResponse",
|
|
||||||
ColumnFormatCodes: src.ColumnFormatCodes,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,112 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type DataRow struct {
|
|
||||||
Values [][]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*DataRow) Backend() {}
|
|
||||||
|
|
||||||
func (dst *DataRow) Decode(src []byte) error {
|
|
||||||
if len(src) < 2 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "DataRow"}
|
|
||||||
}
|
|
||||||
rp := 0
|
|
||||||
fieldCount := int(binary.BigEndian.Uint16(src[rp:]))
|
|
||||||
rp += 2
|
|
||||||
|
|
||||||
// If the capacity of the values slice is too small OR substantially too
|
|
||||||
// large reallocate. This is too avoid one row with many columns from
|
|
||||||
// permanently allocating memory.
|
|
||||||
if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 {
|
|
||||||
newCap := 32
|
|
||||||
if newCap < fieldCount {
|
|
||||||
newCap = fieldCount
|
|
||||||
}
|
|
||||||
dst.Values = make([][]byte, fieldCount, newCap)
|
|
||||||
} else {
|
|
||||||
dst.Values = dst.Values[:fieldCount]
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < fieldCount; i++ {
|
|
||||||
if len(src[rp:]) < 4 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "DataRow"}
|
|
||||||
}
|
|
||||||
|
|
||||||
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
|
||||||
rp += 4
|
|
||||||
|
|
||||||
// null
|
|
||||||
if msgSize == -1 {
|
|
||||||
dst.Values[i] = nil
|
|
||||||
} else {
|
|
||||||
if len(src[rp:]) < msgSize {
|
|
||||||
return &invalidMessageFormatErr{messageType: "DataRow"}
|
|
||||||
}
|
|
||||||
|
|
||||||
dst.Values[i] = src[rp : rp+msgSize]
|
|
||||||
rp += msgSize
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *DataRow) Encode(dst []byte) []byte {
|
|
||||||
dst = append(dst, 'D')
|
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
|
|
||||||
for _, v := range src.Values {
|
|
||||||
if v == nil {
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
dst = pgio.AppendInt32(dst, int32(len(v)))
|
|
||||||
dst = append(dst, v...)
|
|
||||||
}
|
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *DataRow) MarshalJSON() ([]byte, error) {
|
|
||||||
formattedValues := make([]map[string]string, len(src.Values))
|
|
||||||
for i, v := range src.Values {
|
|
||||||
if v == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var hasNonPrintable bool
|
|
||||||
for _, b := range v {
|
|
||||||
if b < 32 {
|
|
||||||
hasNonPrintable = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if hasNonPrintable {
|
|
||||||
formattedValues[i] = map[string]string{"binary": hex.EncodeToString(v)}
|
|
||||||
} else {
|
|
||||||
formattedValues[i] = map[string]string{"text": string(v)}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
Values []map[string]string
|
|
||||||
}{
|
|
||||||
Type: "DataRow",
|
|
||||||
Values: formattedValues,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,59 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Describe struct {
|
|
||||||
ObjectType byte // 'S' = prepared statement, 'P' = portal
|
|
||||||
Name string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*Describe) Frontend() {}
|
|
||||||
|
|
||||||
func (dst *Describe) Decode(src []byte) error {
|
|
||||||
if len(src) < 2 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "Describe"}
|
|
||||||
}
|
|
||||||
|
|
||||||
dst.ObjectType = src[0]
|
|
||||||
rp := 1
|
|
||||||
|
|
||||||
idx := bytes.IndexByte(src[rp:], 0)
|
|
||||||
if idx != len(src[rp:])-1 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "Describe"}
|
|
||||||
}
|
|
||||||
|
|
||||||
dst.Name = string(src[rp : len(src)-1])
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *Describe) Encode(dst []byte) []byte {
|
|
||||||
dst = append(dst, 'D')
|
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = append(dst, src.ObjectType)
|
|
||||||
dst = append(dst, src.Name...)
|
|
||||||
dst = append(dst, 0)
|
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *Describe) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
ObjectType string
|
|
||||||
Name string
|
|
||||||
}{
|
|
||||||
Type: "Describe",
|
|
||||||
ObjectType: string(src.ObjectType),
|
|
||||||
Name: src.Name,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
type EmptyQueryResponse struct{}
|
|
||||||
|
|
||||||
func (*EmptyQueryResponse) Backend() {}
|
|
||||||
|
|
||||||
func (dst *EmptyQueryResponse) Decode(src []byte) error {
|
|
||||||
if len(src) != 0 {
|
|
||||||
return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *EmptyQueryResponse) Encode(dst []byte) []byte {
|
|
||||||
return append(dst, 'I', 0, 0, 0, 4)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *EmptyQueryResponse) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
}{
|
|
||||||
Type: "EmptyQueryResponse",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,214 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"strconv"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ErrorResponse struct {
|
|
||||||
Severity string
|
|
||||||
Code string
|
|
||||||
Message string
|
|
||||||
Detail string
|
|
||||||
Hint string
|
|
||||||
Position int32
|
|
||||||
InternalPosition int32
|
|
||||||
InternalQuery string
|
|
||||||
Where string
|
|
||||||
SchemaName string
|
|
||||||
TableName string
|
|
||||||
ColumnName string
|
|
||||||
DataTypeName string
|
|
||||||
ConstraintName string
|
|
||||||
File string
|
|
||||||
Line int32
|
|
||||||
Routine string
|
|
||||||
|
|
||||||
UnknownFields map[byte]string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*ErrorResponse) Backend() {}
|
|
||||||
|
|
||||||
func (dst *ErrorResponse) Decode(src []byte) error {
|
|
||||||
*dst = ErrorResponse{}
|
|
||||||
|
|
||||||
buf := bytes.NewBuffer(src)
|
|
||||||
|
|
||||||
for {
|
|
||||||
k, err := buf.ReadByte()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if k == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
vb, err := buf.ReadBytes(0)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
v := string(vb[:len(vb)-1])
|
|
||||||
|
|
||||||
switch k {
|
|
||||||
case 'S':
|
|
||||||
dst.Severity = v
|
|
||||||
case 'C':
|
|
||||||
dst.Code = v
|
|
||||||
case 'M':
|
|
||||||
dst.Message = v
|
|
||||||
case 'D':
|
|
||||||
dst.Detail = v
|
|
||||||
case 'H':
|
|
||||||
dst.Hint = v
|
|
||||||
case 'P':
|
|
||||||
s := v
|
|
||||||
n, _ := strconv.ParseInt(s, 10, 32)
|
|
||||||
dst.Position = int32(n)
|
|
||||||
case 'p':
|
|
||||||
s := v
|
|
||||||
n, _ := strconv.ParseInt(s, 10, 32)
|
|
||||||
dst.InternalPosition = int32(n)
|
|
||||||
case 'q':
|
|
||||||
dst.InternalQuery = v
|
|
||||||
case 'W':
|
|
||||||
dst.Where = v
|
|
||||||
case 's':
|
|
||||||
dst.SchemaName = v
|
|
||||||
case 't':
|
|
||||||
dst.TableName = v
|
|
||||||
case 'c':
|
|
||||||
dst.ColumnName = v
|
|
||||||
case 'd':
|
|
||||||
dst.DataTypeName = v
|
|
||||||
case 'n':
|
|
||||||
dst.ConstraintName = v
|
|
||||||
case 'F':
|
|
||||||
dst.File = v
|
|
||||||
case 'L':
|
|
||||||
s := v
|
|
||||||
n, _ := strconv.ParseInt(s, 10, 32)
|
|
||||||
dst.Line = int32(n)
|
|
||||||
case 'R':
|
|
||||||
dst.Routine = v
|
|
||||||
|
|
||||||
default:
|
|
||||||
if dst.UnknownFields == nil {
|
|
||||||
dst.UnknownFields = make(map[byte]string)
|
|
||||||
}
|
|
||||||
dst.UnknownFields[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *ErrorResponse) Encode(dst []byte) []byte {
|
|
||||||
return append(dst, src.marshalBinary('E')...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
|
|
||||||
var bigEndian BigEndianBuf
|
|
||||||
buf := &bytes.Buffer{}
|
|
||||||
|
|
||||||
buf.WriteByte(typeByte)
|
|
||||||
buf.Write(bigEndian.Uint32(0))
|
|
||||||
|
|
||||||
if src.Severity != "" {
|
|
||||||
buf.WriteByte('S')
|
|
||||||
buf.WriteString(src.Severity)
|
|
||||||
buf.WriteByte(0)
|
|
||||||
}
|
|
||||||
if src.Code != "" {
|
|
||||||
buf.WriteByte('C')
|
|
||||||
buf.WriteString(src.Code)
|
|
||||||
buf.WriteByte(0)
|
|
||||||
}
|
|
||||||
if src.Message != "" {
|
|
||||||
buf.WriteByte('M')
|
|
||||||
buf.WriteString(src.Message)
|
|
||||||
buf.WriteByte(0)
|
|
||||||
}
|
|
||||||
if src.Detail != "" {
|
|
||||||
buf.WriteByte('D')
|
|
||||||
buf.WriteString(src.Detail)
|
|
||||||
buf.WriteByte(0)
|
|
||||||
}
|
|
||||||
if src.Hint != "" {
|
|
||||||
buf.WriteByte('H')
|
|
||||||
buf.WriteString(src.Hint)
|
|
||||||
buf.WriteByte(0)
|
|
||||||
}
|
|
||||||
if src.Position != 0 {
|
|
||||||
buf.WriteByte('P')
|
|
||||||
buf.WriteString(strconv.Itoa(int(src.Position)))
|
|
||||||
buf.WriteByte(0)
|
|
||||||
}
|
|
||||||
if src.InternalPosition != 0 {
|
|
||||||
buf.WriteByte('p')
|
|
||||||
buf.WriteString(strconv.Itoa(int(src.InternalPosition)))
|
|
||||||
buf.WriteByte(0)
|
|
||||||
}
|
|
||||||
if src.InternalQuery != "" {
|
|
||||||
buf.WriteByte('q')
|
|
||||||
buf.WriteString(src.InternalQuery)
|
|
||||||
buf.WriteByte(0)
|
|
||||||
}
|
|
||||||
if src.Where != "" {
|
|
||||||
buf.WriteByte('W')
|
|
||||||
buf.WriteString(src.Where)
|
|
||||||
buf.WriteByte(0)
|
|
||||||
}
|
|
||||||
if src.SchemaName != "" {
|
|
||||||
buf.WriteByte('s')
|
|
||||||
buf.WriteString(src.SchemaName)
|
|
||||||
buf.WriteByte(0)
|
|
||||||
}
|
|
||||||
if src.TableName != "" {
|
|
||||||
buf.WriteByte('t')
|
|
||||||
buf.WriteString(src.TableName)
|
|
||||||
buf.WriteByte(0)
|
|
||||||
}
|
|
||||||
if src.ColumnName != "" {
|
|
||||||
buf.WriteByte('c')
|
|
||||||
buf.WriteString(src.ColumnName)
|
|
||||||
buf.WriteByte(0)
|
|
||||||
}
|
|
||||||
if src.DataTypeName != "" {
|
|
||||||
buf.WriteByte('d')
|
|
||||||
buf.WriteString(src.DataTypeName)
|
|
||||||
buf.WriteByte(0)
|
|
||||||
}
|
|
||||||
if src.ConstraintName != "" {
|
|
||||||
buf.WriteByte('n')
|
|
||||||
buf.WriteString(src.ConstraintName)
|
|
||||||
buf.WriteByte(0)
|
|
||||||
}
|
|
||||||
if src.File != "" {
|
|
||||||
buf.WriteByte('F')
|
|
||||||
buf.WriteString(src.File)
|
|
||||||
buf.WriteByte(0)
|
|
||||||
}
|
|
||||||
if src.Line != 0 {
|
|
||||||
buf.WriteByte('L')
|
|
||||||
buf.WriteString(strconv.Itoa(int(src.Line)))
|
|
||||||
buf.WriteByte(0)
|
|
||||||
}
|
|
||||||
if src.Routine != "" {
|
|
||||||
buf.WriteByte('R')
|
|
||||||
buf.WriteString(src.Routine)
|
|
||||||
buf.WriteByte(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v := range src.UnknownFields {
|
|
||||||
buf.WriteByte(k)
|
|
||||||
buf.WriteByte(0)
|
|
||||||
buf.WriteString(v)
|
|
||||||
buf.WriteByte(0)
|
|
||||||
}
|
|
||||||
buf.WriteByte(0)
|
|
||||||
|
|
||||||
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
|
|
||||||
|
|
||||||
return buf.Bytes()
|
|
||||||
}
|
|
||||||
@@ -1,60 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Execute struct {
|
|
||||||
Portal string
|
|
||||||
MaxRows uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*Execute) Frontend() {}
|
|
||||||
|
|
||||||
func (dst *Execute) Decode(src []byte) error {
|
|
||||||
buf := bytes.NewBuffer(src)
|
|
||||||
|
|
||||||
b, err := buf.ReadBytes(0)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
dst.Portal = string(b[:len(b)-1])
|
|
||||||
|
|
||||||
if buf.Len() < 4 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "Execute"}
|
|
||||||
}
|
|
||||||
dst.MaxRows = binary.BigEndian.Uint32(buf.Next(4))
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *Execute) Encode(dst []byte) []byte {
|
|
||||||
dst = append(dst, 'E')
|
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = append(dst, src.Portal...)
|
|
||||||
dst = append(dst, 0)
|
|
||||||
|
|
||||||
dst = pgio.AppendUint32(dst, src.MaxRows)
|
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *Execute) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
Portal string
|
|
||||||
MaxRows uint32
|
|
||||||
}{
|
|
||||||
Type: "Execute",
|
|
||||||
Portal: src.Portal,
|
|
||||||
MaxRows: src.MaxRows,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Flush struct{}
|
|
||||||
|
|
||||||
func (*Flush) Frontend() {}
|
|
||||||
|
|
||||||
func (dst *Flush) Decode(src []byte) error {
|
|
||||||
if len(src) != 0 {
|
|
||||||
return &invalidMessageLenErr{messageType: "Flush", expectedLen: 0, actualLen: len(src)}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *Flush) Encode(dst []byte) []byte {
|
|
||||||
return append(dst, 'H', 0, 0, 0, 4)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *Flush) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
}{
|
|
||||||
Type: "Flush",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,128 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/chunkreader"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Frontend struct {
|
|
||||||
cr *chunkreader.ChunkReader
|
|
||||||
w io.Writer
|
|
||||||
|
|
||||||
// Backend message flyweights
|
|
||||||
authentication Authentication
|
|
||||||
backendKeyData BackendKeyData
|
|
||||||
bindComplete BindComplete
|
|
||||||
closeComplete CloseComplete
|
|
||||||
commandComplete CommandComplete
|
|
||||||
copyBothResponse CopyBothResponse
|
|
||||||
copyData CopyData
|
|
||||||
copyInResponse CopyInResponse
|
|
||||||
copyOutResponse CopyOutResponse
|
|
||||||
copyDone CopyDone
|
|
||||||
copyFail CopyFail
|
|
||||||
dataRow DataRow
|
|
||||||
emptyQueryResponse EmptyQueryResponse
|
|
||||||
errorResponse ErrorResponse
|
|
||||||
functionCallResponse FunctionCallResponse
|
|
||||||
noData NoData
|
|
||||||
noticeResponse NoticeResponse
|
|
||||||
notificationResponse NotificationResponse
|
|
||||||
parameterDescription ParameterDescription
|
|
||||||
parameterStatus ParameterStatus
|
|
||||||
parseComplete ParseComplete
|
|
||||||
readyForQuery ReadyForQuery
|
|
||||||
rowDescription RowDescription
|
|
||||||
|
|
||||||
bodyLen int
|
|
||||||
msgType byte
|
|
||||||
partialMsg bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) {
|
|
||||||
cr := chunkreader.NewChunkReader(r)
|
|
||||||
return &Frontend{cr: cr, w: w}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Frontend) Send(msg FrontendMessage) error {
|
|
||||||
_, err := b.w.Write(msg.Encode(nil))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Frontend) Receive() (BackendMessage, error) {
|
|
||||||
if !b.partialMsg {
|
|
||||||
header, err := b.cr.Next(5)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
b.msgType = header[0]
|
|
||||||
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
|
|
||||||
b.partialMsg = true
|
|
||||||
}
|
|
||||||
|
|
||||||
var msg BackendMessage
|
|
||||||
switch b.msgType {
|
|
||||||
case '1':
|
|
||||||
msg = &b.parseComplete
|
|
||||||
case '2':
|
|
||||||
msg = &b.bindComplete
|
|
||||||
case '3':
|
|
||||||
msg = &b.closeComplete
|
|
||||||
case 'A':
|
|
||||||
msg = &b.notificationResponse
|
|
||||||
case 'c':
|
|
||||||
msg = &b.copyDone
|
|
||||||
case 'C':
|
|
||||||
msg = &b.commandComplete
|
|
||||||
case 'd':
|
|
||||||
msg = &b.copyData
|
|
||||||
case 'D':
|
|
||||||
msg = &b.dataRow
|
|
||||||
case 'E':
|
|
||||||
msg = &b.errorResponse
|
|
||||||
case 'f':
|
|
||||||
msg = &b.copyFail
|
|
||||||
case 'G':
|
|
||||||
msg = &b.copyInResponse
|
|
||||||
case 'H':
|
|
||||||
msg = &b.copyOutResponse
|
|
||||||
case 'I':
|
|
||||||
msg = &b.emptyQueryResponse
|
|
||||||
case 'K':
|
|
||||||
msg = &b.backendKeyData
|
|
||||||
case 'n':
|
|
||||||
msg = &b.noData
|
|
||||||
case 'N':
|
|
||||||
msg = &b.noticeResponse
|
|
||||||
case 'R':
|
|
||||||
msg = &b.authentication
|
|
||||||
case 'S':
|
|
||||||
msg = &b.parameterStatus
|
|
||||||
case 't':
|
|
||||||
msg = &b.parameterDescription
|
|
||||||
case 'T':
|
|
||||||
msg = &b.rowDescription
|
|
||||||
case 'V':
|
|
||||||
msg = &b.functionCallResponse
|
|
||||||
case 'W':
|
|
||||||
msg = &b.copyBothResponse
|
|
||||||
case 'Z':
|
|
||||||
msg = &b.readyForQuery
|
|
||||||
default:
|
|
||||||
return nil, errors.Errorf("unknown message type: %c", b.msgType)
|
|
||||||
}
|
|
||||||
|
|
||||||
msgBody, err := b.cr.Next(b.bodyLen)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
b.partialMsg = false
|
|
||||||
|
|
||||||
err = msg.Decode(msgBody)
|
|
||||||
return msg, err
|
|
||||||
}
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
package pgproto3_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgproto3"
|
|
||||||
)
|
|
||||||
|
|
||||||
type interruptReader struct {
|
|
||||||
chunks [][]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ir *interruptReader) Read(p []byte) (n int, err error) {
|
|
||||||
if len(ir.chunks) == 0 {
|
|
||||||
return 0, errors.New("no data")
|
|
||||||
}
|
|
||||||
|
|
||||||
n = copy(p, ir.chunks[0])
|
|
||||||
if n != len(ir.chunks[0]) {
|
|
||||||
panic("this test reader doesn't support partial reads of chunks")
|
|
||||||
}
|
|
||||||
|
|
||||||
ir.chunks = ir.chunks[1:]
|
|
||||||
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ir *interruptReader) push(p []byte) {
|
|
||||||
ir.chunks = append(ir.chunks, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFrontendReceiveInterrupted(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
server := &interruptReader{}
|
|
||||||
server.push([]byte{'Z', 0, 0, 0, 5})
|
|
||||||
|
|
||||||
frontend, err := pgproto3.NewFrontend(server, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg, err := frontend.Receive()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected err")
|
|
||||||
}
|
|
||||||
if msg != nil {
|
|
||||||
t.Fatalf("did not expect msg, but %v", msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
server.push([]byte{'I'})
|
|
||||||
|
|
||||||
msg, err = frontend.Receive()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if msg, ok := msg.(*pgproto3.ReadyForQuery); !ok || msg.TxStatus != 'I' {
|
|
||||||
t.Fatalf("unexpected msg: %v", msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,78 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type FunctionCallResponse struct {
|
|
||||||
Result []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*FunctionCallResponse) Backend() {}
|
|
||||||
|
|
||||||
func (dst *FunctionCallResponse) Decode(src []byte) error {
|
|
||||||
if len(src) < 4 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
|
|
||||||
}
|
|
||||||
rp := 0
|
|
||||||
resultSize := int(binary.BigEndian.Uint32(src[rp:]))
|
|
||||||
rp += 4
|
|
||||||
|
|
||||||
if resultSize == -1 {
|
|
||||||
dst.Result = nil
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(src[rp:]) != resultSize {
|
|
||||||
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
|
|
||||||
}
|
|
||||||
|
|
||||||
dst.Result = src[rp:]
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *FunctionCallResponse) Encode(dst []byte) []byte {
|
|
||||||
dst = append(dst, 'V')
|
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
if src.Result == nil {
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
} else {
|
|
||||||
dst = pgio.AppendInt32(dst, int32(len(src.Result)))
|
|
||||||
dst = append(dst, src.Result...)
|
|
||||||
}
|
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *FunctionCallResponse) MarshalJSON() ([]byte, error) {
|
|
||||||
var formattedValue map[string]string
|
|
||||||
var hasNonPrintable bool
|
|
||||||
for _, b := range src.Result {
|
|
||||||
if b < 32 {
|
|
||||||
hasNonPrintable = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if hasNonPrintable {
|
|
||||||
formattedValue = map[string]string{"binary": hex.EncodeToString(src.Result)}
|
|
||||||
} else {
|
|
||||||
formattedValue = map[string]string{"text": string(src.Result)}
|
|
||||||
}
|
|
||||||
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
Result map[string]string
|
|
||||||
}{
|
|
||||||
Type: "FunctionCallResponse",
|
|
||||||
Result: formattedValue,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
type NoData struct{}
|
|
||||||
|
|
||||||
func (*NoData) Backend() {}
|
|
||||||
|
|
||||||
func (dst *NoData) Decode(src []byte) error {
|
|
||||||
if len(src) != 0 {
|
|
||||||
return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *NoData) Encode(dst []byte) []byte {
|
|
||||||
return append(dst, 'n', 0, 0, 0, 4)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *NoData) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
}{
|
|
||||||
Type: "NoData",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
type NoticeResponse ErrorResponse
|
|
||||||
|
|
||||||
func (*NoticeResponse) Backend() {}
|
|
||||||
|
|
||||||
func (dst *NoticeResponse) Decode(src []byte) error {
|
|
||||||
return (*ErrorResponse)(dst).Decode(src)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *NoticeResponse) Encode(dst []byte) []byte {
|
|
||||||
return append(dst, (*ErrorResponse)(src).marshalBinary('N')...)
|
|
||||||
}
|
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type NotificationResponse struct {
|
|
||||||
PID uint32
|
|
||||||
Channel string
|
|
||||||
Payload string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*NotificationResponse) Backend() {}
|
|
||||||
|
|
||||||
func (dst *NotificationResponse) Decode(src []byte) error {
|
|
||||||
buf := bytes.NewBuffer(src)
|
|
||||||
|
|
||||||
pid := binary.BigEndian.Uint32(buf.Next(4))
|
|
||||||
|
|
||||||
b, err := buf.ReadBytes(0)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
channel := string(b[:len(b)-1])
|
|
||||||
|
|
||||||
b, err = buf.ReadBytes(0)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
payload := string(b[:len(b)-1])
|
|
||||||
|
|
||||||
*dst = NotificationResponse{PID: pid, Channel: channel, Payload: payload}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *NotificationResponse) Encode(dst []byte) []byte {
|
|
||||||
dst = append(dst, 'A')
|
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = append(dst, src.Channel...)
|
|
||||||
dst = append(dst, 0)
|
|
||||||
dst = append(dst, src.Payload...)
|
|
||||||
dst = append(dst, 0)
|
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *NotificationResponse) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
PID uint32
|
|
||||||
Channel string
|
|
||||||
Payload string
|
|
||||||
}{
|
|
||||||
Type: "NotificationResponse",
|
|
||||||
PID: src.PID,
|
|
||||||
Channel: src.Channel,
|
|
||||||
Payload: src.Payload,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,61 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ParameterDescription struct {
|
|
||||||
ParameterOIDs []uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*ParameterDescription) Backend() {}
|
|
||||||
|
|
||||||
func (dst *ParameterDescription) Decode(src []byte) error {
|
|
||||||
buf := bytes.NewBuffer(src)
|
|
||||||
|
|
||||||
if buf.Len() < 2 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "ParameterDescription"}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reported parameter count will be incorrect when number of args is greater than uint16
|
|
||||||
buf.Next(2)
|
|
||||||
// Instead infer parameter count by remaining size of message
|
|
||||||
parameterCount := buf.Len() / 4
|
|
||||||
|
|
||||||
*dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)}
|
|
||||||
|
|
||||||
for i := 0; i < parameterCount; i++ {
|
|
||||||
dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *ParameterDescription) Encode(dst []byte) []byte {
|
|
||||||
dst = append(dst, 't')
|
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
|
|
||||||
for _, oid := range src.ParameterOIDs {
|
|
||||||
dst = pgio.AppendUint32(dst, oid)
|
|
||||||
}
|
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *ParameterDescription) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
ParameterOIDs []uint32
|
|
||||||
}{
|
|
||||||
Type: "ParameterDescription",
|
|
||||||
ParameterOIDs: src.ParameterOIDs,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,61 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ParameterStatus struct {
|
|
||||||
Name string
|
|
||||||
Value string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*ParameterStatus) Backend() {}
|
|
||||||
|
|
||||||
func (dst *ParameterStatus) Decode(src []byte) error {
|
|
||||||
buf := bytes.NewBuffer(src)
|
|
||||||
|
|
||||||
b, err := buf.ReadBytes(0)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
name := string(b[:len(b)-1])
|
|
||||||
|
|
||||||
b, err = buf.ReadBytes(0)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
value := string(b[:len(b)-1])
|
|
||||||
|
|
||||||
*dst = ParameterStatus{Name: name, Value: value}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *ParameterStatus) Encode(dst []byte) []byte {
|
|
||||||
dst = append(dst, 'S')
|
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = append(dst, src.Name...)
|
|
||||||
dst = append(dst, 0)
|
|
||||||
dst = append(dst, src.Value...)
|
|
||||||
dst = append(dst, 0)
|
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *ParameterStatus) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
Name string
|
|
||||||
Value string
|
|
||||||
}{
|
|
||||||
Type: "ParameterStatus",
|
|
||||||
Name: ps.Name,
|
|
||||||
Value: ps.Value,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,83 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Parse struct {
|
|
||||||
Name string
|
|
||||||
Query string
|
|
||||||
ParameterOIDs []uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*Parse) Frontend() {}
|
|
||||||
|
|
||||||
func (dst *Parse) Decode(src []byte) error {
|
|
||||||
*dst = Parse{}
|
|
||||||
|
|
||||||
buf := bytes.NewBuffer(src)
|
|
||||||
|
|
||||||
b, err := buf.ReadBytes(0)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
dst.Name = string(b[:len(b)-1])
|
|
||||||
|
|
||||||
b, err = buf.ReadBytes(0)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
dst.Query = string(b[:len(b)-1])
|
|
||||||
|
|
||||||
if buf.Len() < 2 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "Parse"}
|
|
||||||
}
|
|
||||||
parameterOIDCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
|
||||||
|
|
||||||
for i := 0; i < parameterOIDCount; i++ {
|
|
||||||
if buf.Len() < 4 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "Parse"}
|
|
||||||
}
|
|
||||||
dst.ParameterOIDs = append(dst.ParameterOIDs, binary.BigEndian.Uint32(buf.Next(4)))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *Parse) Encode(dst []byte) []byte {
|
|
||||||
dst = append(dst, 'P')
|
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = append(dst, src.Name...)
|
|
||||||
dst = append(dst, 0)
|
|
||||||
dst = append(dst, src.Query...)
|
|
||||||
dst = append(dst, 0)
|
|
||||||
|
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
|
|
||||||
for _, oid := range src.ParameterOIDs {
|
|
||||||
dst = pgio.AppendUint32(dst, oid)
|
|
||||||
}
|
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *Parse) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
Name string
|
|
||||||
Query string
|
|
||||||
ParameterOIDs []uint32
|
|
||||||
}{
|
|
||||||
Type: "Parse",
|
|
||||||
Name: src.Name,
|
|
||||||
Query: src.Query,
|
|
||||||
ParameterOIDs: src.ParameterOIDs,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ParseComplete struct{}
|
|
||||||
|
|
||||||
func (*ParseComplete) Backend() {}
|
|
||||||
|
|
||||||
func (dst *ParseComplete) Decode(src []byte) error {
|
|
||||||
if len(src) != 0 {
|
|
||||||
return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *ParseComplete) Encode(dst []byte) []byte {
|
|
||||||
return append(dst, '1', 0, 0, 0, 4)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *ParseComplete) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
}{
|
|
||||||
Type: "ParseComplete",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type PasswordMessage struct {
|
|
||||||
Password string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*PasswordMessage) Frontend() {}
|
|
||||||
|
|
||||||
func (dst *PasswordMessage) Decode(src []byte) error {
|
|
||||||
buf := bytes.NewBuffer(src)
|
|
||||||
|
|
||||||
b, err := buf.ReadBytes(0)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
dst.Password = string(b[:len(b)-1])
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *PasswordMessage) Encode(dst []byte) []byte {
|
|
||||||
dst = append(dst, 'p')
|
|
||||||
dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1))
|
|
||||||
|
|
||||||
dst = append(dst, src.Password...)
|
|
||||||
dst = append(dst, 0)
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *PasswordMessage) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
Password string
|
|
||||||
}{
|
|
||||||
Type: "PasswordMessage",
|
|
||||||
Password: src.Password,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import "fmt"
|
|
||||||
|
|
||||||
// Message is the interface implemented by an object that can decode and encode
|
|
||||||
// a particular PostgreSQL message.
|
|
||||||
type Message interface {
|
|
||||||
// Decode is allowed and expected to retain a reference to data after
|
|
||||||
// returning (unlike encoding.BinaryUnmarshaler).
|
|
||||||
Decode(data []byte) error
|
|
||||||
|
|
||||||
// Encode appends itself to dst and returns the new buffer.
|
|
||||||
Encode(dst []byte) []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type FrontendMessage interface {
|
|
||||||
Message
|
|
||||||
Frontend() // no-op method to distinguish frontend from backend methods
|
|
||||||
}
|
|
||||||
|
|
||||||
type BackendMessage interface {
|
|
||||||
Message
|
|
||||||
Backend() // no-op method to distinguish frontend from backend methods
|
|
||||||
}
|
|
||||||
|
|
||||||
type invalidMessageLenErr struct {
|
|
||||||
messageType string
|
|
||||||
expectedLen int
|
|
||||||
actualLen int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *invalidMessageLenErr) Error() string {
|
|
||||||
return fmt.Sprintf("%s body must have length of %d, but it is %d", e.messageType, e.expectedLen, e.actualLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
type invalidMessageFormatErr struct {
|
|
||||||
messageType string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *invalidMessageFormatErr) Error() string {
|
|
||||||
return fmt.Sprintf("%s body is invalid", e.messageType)
|
|
||||||
}
|
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Query struct {
|
|
||||||
String string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*Query) Frontend() {}
|
|
||||||
|
|
||||||
func (dst *Query) Decode(src []byte) error {
|
|
||||||
i := bytes.IndexByte(src, 0)
|
|
||||||
if i != len(src)-1 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "Query"}
|
|
||||||
}
|
|
||||||
|
|
||||||
dst.String = string(src[:i])
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *Query) Encode(dst []byte) []byte {
|
|
||||||
dst = append(dst, 'Q')
|
|
||||||
dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1))
|
|
||||||
|
|
||||||
dst = append(dst, src.String...)
|
|
||||||
dst = append(dst, 0)
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *Query) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
String string
|
|
||||||
}{
|
|
||||||
Type: "Query",
|
|
||||||
String: src.String,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ReadyForQuery struct {
|
|
||||||
TxStatus byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*ReadyForQuery) Backend() {}
|
|
||||||
|
|
||||||
func (dst *ReadyForQuery) Decode(src []byte) error {
|
|
||||||
if len(src) != 1 {
|
|
||||||
return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)}
|
|
||||||
}
|
|
||||||
|
|
||||||
dst.TxStatus = src[0]
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *ReadyForQuery) Encode(dst []byte) []byte {
|
|
||||||
return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *ReadyForQuery) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
TxStatus string
|
|
||||||
}{
|
|
||||||
Type: "ReadyForQuery",
|
|
||||||
TxStatus: string(src.TxStatus),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,100 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
TextFormat = 0
|
|
||||||
BinaryFormat = 1
|
|
||||||
)
|
|
||||||
|
|
||||||
type FieldDescription struct {
|
|
||||||
Name string
|
|
||||||
TableOID uint32
|
|
||||||
TableAttributeNumber uint16
|
|
||||||
DataTypeOID uint32
|
|
||||||
DataTypeSize int16
|
|
||||||
TypeModifier int32
|
|
||||||
Format int16
|
|
||||||
}
|
|
||||||
|
|
||||||
type RowDescription struct {
|
|
||||||
Fields []FieldDescription
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*RowDescription) Backend() {}
|
|
||||||
|
|
||||||
func (dst *RowDescription) Decode(src []byte) error {
|
|
||||||
buf := bytes.NewBuffer(src)
|
|
||||||
|
|
||||||
if buf.Len() < 2 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "RowDescription"}
|
|
||||||
}
|
|
||||||
fieldCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
|
||||||
|
|
||||||
dst.Fields = dst.Fields[0:0]
|
|
||||||
|
|
||||||
for i := 0; i < fieldCount; i++ {
|
|
||||||
var fd FieldDescription
|
|
||||||
bName, err := buf.ReadBytes(0)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
fd.Name = string(bName[:len(bName)-1])
|
|
||||||
|
|
||||||
// Since buf.Next() doesn't return an error if we hit the end of the buffer
|
|
||||||
// check Len ahead of time
|
|
||||||
if buf.Len() < 18 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "RowDescription"}
|
|
||||||
}
|
|
||||||
|
|
||||||
fd.TableOID = binary.BigEndian.Uint32(buf.Next(4))
|
|
||||||
fd.TableAttributeNumber = binary.BigEndian.Uint16(buf.Next(2))
|
|
||||||
fd.DataTypeOID = binary.BigEndian.Uint32(buf.Next(4))
|
|
||||||
fd.DataTypeSize = int16(binary.BigEndian.Uint16(buf.Next(2)))
|
|
||||||
fd.TypeModifier = int32(binary.BigEndian.Uint32(buf.Next(4)))
|
|
||||||
fd.Format = int16(binary.BigEndian.Uint16(buf.Next(2)))
|
|
||||||
|
|
||||||
dst.Fields = append(dst.Fields, fd)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *RowDescription) Encode(dst []byte) []byte {
|
|
||||||
dst = append(dst, 'T')
|
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.Fields)))
|
|
||||||
for _, fd := range src.Fields {
|
|
||||||
dst = append(dst, fd.Name...)
|
|
||||||
dst = append(dst, 0)
|
|
||||||
|
|
||||||
dst = pgio.AppendUint32(dst, fd.TableOID)
|
|
||||||
dst = pgio.AppendUint16(dst, fd.TableAttributeNumber)
|
|
||||||
dst = pgio.AppendUint32(dst, fd.DataTypeOID)
|
|
||||||
dst = pgio.AppendInt16(dst, fd.DataTypeSize)
|
|
||||||
dst = pgio.AppendInt32(dst, fd.TypeModifier)
|
|
||||||
dst = pgio.AppendInt16(dst, fd.Format)
|
|
||||||
}
|
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *RowDescription) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
Fields []FieldDescription
|
|
||||||
}{
|
|
||||||
Type: "RowDescription",
|
|
||||||
Fields: src.Fields,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,97 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
ProtocolVersionNumber = 196608 // 3.0
|
|
||||||
sslRequestNumber = 80877103
|
|
||||||
)
|
|
||||||
|
|
||||||
type StartupMessage struct {
|
|
||||||
ProtocolVersion uint32
|
|
||||||
Parameters map[string]string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*StartupMessage) Frontend() {}
|
|
||||||
|
|
||||||
func (dst *StartupMessage) Decode(src []byte) error {
|
|
||||||
if len(src) < 4 {
|
|
||||||
return errors.Errorf("startup message too short")
|
|
||||||
}
|
|
||||||
|
|
||||||
dst.ProtocolVersion = binary.BigEndian.Uint32(src)
|
|
||||||
rp := 4
|
|
||||||
|
|
||||||
if dst.ProtocolVersion == sslRequestNumber {
|
|
||||||
return errors.Errorf("can't handle ssl connection request")
|
|
||||||
}
|
|
||||||
|
|
||||||
if dst.ProtocolVersion != ProtocolVersionNumber {
|
|
||||||
return errors.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion)
|
|
||||||
}
|
|
||||||
|
|
||||||
dst.Parameters = make(map[string]string)
|
|
||||||
for {
|
|
||||||
idx := bytes.IndexByte(src[rp:], 0)
|
|
||||||
if idx < 0 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "StartupMesage"}
|
|
||||||
}
|
|
||||||
key := string(src[rp : rp+idx])
|
|
||||||
rp += idx + 1
|
|
||||||
|
|
||||||
idx = bytes.IndexByte(src[rp:], 0)
|
|
||||||
if idx < 0 {
|
|
||||||
return &invalidMessageFormatErr{messageType: "StartupMesage"}
|
|
||||||
}
|
|
||||||
value := string(src[rp : rp+idx])
|
|
||||||
rp += idx + 1
|
|
||||||
|
|
||||||
dst.Parameters[key] = value
|
|
||||||
|
|
||||||
if len(src[rp:]) == 1 {
|
|
||||||
if src[rp] != 0 {
|
|
||||||
return errors.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp])
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *StartupMessage) Encode(dst []byte) []byte {
|
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
|
|
||||||
dst = pgio.AppendUint32(dst, src.ProtocolVersion)
|
|
||||||
for k, v := range src.Parameters {
|
|
||||||
dst = append(dst, k...)
|
|
||||||
dst = append(dst, 0)
|
|
||||||
dst = append(dst, v...)
|
|
||||||
dst = append(dst, 0)
|
|
||||||
}
|
|
||||||
dst = append(dst, 0)
|
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *StartupMessage) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
ProtocolVersion uint32
|
|
||||||
Parameters map[string]string
|
|
||||||
}{
|
|
||||||
Type: "StartupMessage",
|
|
||||||
ProtocolVersion: src.ProtocolVersion,
|
|
||||||
Parameters: src.Parameters,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Sync struct{}
|
|
||||||
|
|
||||||
func (*Sync) Frontend() {}
|
|
||||||
|
|
||||||
func (dst *Sync) Decode(src []byte) error {
|
|
||||||
if len(src) != 0 {
|
|
||||||
return &invalidMessageLenErr{messageType: "Sync", expectedLen: 0, actualLen: len(src)}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *Sync) Encode(dst []byte) []byte {
|
|
||||||
return append(dst, 'S', 0, 0, 0, 4)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *Sync) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
}{
|
|
||||||
Type: "Sync",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Terminate struct{}
|
|
||||||
|
|
||||||
func (*Terminate) Frontend() {}
|
|
||||||
|
|
||||||
func (dst *Terminate) Decode(src []byte) error {
|
|
||||||
if len(src) != 0 {
|
|
||||||
return &invalidMessageLenErr{messageType: "Terminate", expectedLen: 0, actualLen: len(src)}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *Terminate) Encode(dst []byte) []byte {
|
|
||||||
return append(dst, 'X', 0, 0, 0, 4)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (src *Terminate) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
Type string
|
|
||||||
}{
|
|
||||||
Type: "Terminate",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
+1
-1
@@ -8,7 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"unicode"
|
"unicode"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -8,7 +8,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -8,7 +8,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -5,7 +5,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -3,7 +3,7 @@ package pgtype
|
|||||||
import (
|
import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -6,7 +6,7 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -6,7 +6,7 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -10,7 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Hstore represents an hstore column that can be null or have null values
|
// Hstore represents an hstore column that can be null or have null values
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -6,7 +6,7 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -7,7 +7,7 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -3,7 +3,7 @@ package pgtype
|
|||||||
import (
|
import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -7,7 +7,7 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -3,7 +3,7 @@ package pgtype
|
|||||||
import (
|
import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -8,7 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -8,7 +8,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -8,7 +8,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user