diff --git a/README.md b/README.md index 7ba9eaf8..d1c2dac4 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ database/sql compatibility layer for pgx. pgx can be used as a normal database/s Approximately 60 PostgreSQL types are supported including uuid, hstore, json, bytea, numeric, interval, inet, and arrays. These types support database/sql interfaces and are usable even outside of pgx. They are fully tested in pgx and pq. They also support a higher performance interface when used with the pgx driver. -## github.com/jackc/pgx/pgproto3 +## github.com/jackc/pgproto3 pgproto3 provides standalone encoding and decoding of the PostgreSQL v3 wire protocol. This is useful for implementing very low level PostgreSQL tooling. diff --git a/batch.go b/batch.go index ca77dd6d..f26a398e 100644 --- a/batch.go +++ b/batch.go @@ -3,7 +3,7 @@ package pgx import ( "context" - "github.com/jackc/pgx/pgconn" + "github.com/jackc/pgconn" "github.com/jackc/pgx/pgtype" "github.com/pkg/errors" ) diff --git a/batch_test.go b/batch_test.go index 7fec6025..1c37093a 100644 --- a/batch_test.go +++ b/batch_test.go @@ -5,8 +5,8 @@ import ( "os" "testing" + "github.com/jackc/pgconn" "github.com/jackc/pgx" - "github.com/jackc/pgx/pgconn" "github.com/jackc/pgx/pgtype" ) diff --git a/chunkreader/chunkreader.go b/chunkreader/chunkreader.go deleted file mode 100644 index f8d437b2..00000000 --- a/chunkreader/chunkreader.go +++ /dev/null @@ -1,89 +0,0 @@ -package chunkreader - -import ( - "io" -) - -type ChunkReader struct { - r io.Reader - - buf []byte - rp, wp int // buf read position and write position - - options Options -} - -type Options struct { - MinBufLen int // Minimum buffer length -} - -func NewChunkReader(r io.Reader) *ChunkReader { - cr, err := NewChunkReaderEx(r, Options{}) - if err != nil { - panic("default options can't be bad") - } - - return cr -} - -func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) { - if options.MinBufLen == 0 { - options.MinBufLen = 4096 - } - - return &ChunkReader{ - r: r, - buf: make([]byte, options.MinBufLen), - options: options, - }, nil -} - -// Next returns buf filled with the next n bytes. If an error occurs, buf will -// be nil. -func (r *ChunkReader) Next(n int) (buf []byte, err error) { - // n bytes already in buf - if (r.wp - r.rp) >= n { - buf = r.buf[r.rp : r.rp+n] - r.rp += n - return buf, err - } - - // available space in buf is less than n - if len(r.buf) < n { - r.copyBufContents(r.newBuf(n)) - } - - // buf is large enough, but need to shift filled area to start to make enough contiguous space - minReadCount := n - (r.wp - r.rp) - if (len(r.buf) - r.wp) < minReadCount { - newBuf := r.newBuf(n) - r.copyBufContents(newBuf) - } - - if err := r.appendAtLeast(minReadCount); err != nil { - return nil, err - } - - buf = r.buf[r.rp : r.rp+n] - r.rp += n - return buf, nil -} - -func (r *ChunkReader) appendAtLeast(fillLen int) error { - n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen) - r.wp += n - return err -} - -func (r *ChunkReader) newBuf(size int) []byte { - if size < r.options.MinBufLen { - size = r.options.MinBufLen - } - return make([]byte, size) -} - -func (r *ChunkReader) copyBufContents(dest []byte) { - r.wp = copy(dest, r.buf[r.rp:r.wp]) - r.rp = 0 - r.buf = dest -} diff --git a/chunkreader/chunkreader_test.go b/chunkreader/chunkreader_test.go deleted file mode 100644 index 3be07e3c..00000000 --- a/chunkreader/chunkreader_test.go +++ /dev/null @@ -1,96 +0,0 @@ -package chunkreader - -import ( - "bytes" - "testing" -) - -func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { - server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Options{MinBufLen: 4}) - if err != nil { - t.Fatal(err) - } - - src := []byte{1, 2, 3, 4} - server.Write(src) - - n1, err := r.Next(2) - if err != nil { - t.Fatal(err) - } - if bytes.Compare(n1, src[0:2]) != 0 { - t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:2], n1) - } - - n2, err := r.Next(2) - if err != nil { - t.Fatal(err) - } - if bytes.Compare(n2, src[2:4]) != 0 { - t.Fatalf("Expected read bytes to be %v, but they were %v", src[2:4], n2) - } - - if bytes.Compare(r.buf, src) != 0 { - t.Fatalf("Expected r.buf to be %v, but it was %v", src, r.buf) - } - if r.rp != 4 { - t.Fatalf("Expected r.rp to be %v, but it was %v", 4, r.rp) - } - if r.wp != 4 { - t.Fatalf("Expected r.wp to be %v, but it was %v", 4, r.wp) - } -} - -func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) { - server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Options{MinBufLen: 4}) - if err != nil { - t.Fatal(err) - } - - src := []byte{1, 2, 3, 4, 5, 6, 7, 8} - server.Write(src) - - n1, err := r.Next(5) - if err != nil { - t.Fatal(err) - } - if bytes.Compare(n1, src[0:5]) != 0 { - t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:5], n1) - } - if len(r.buf) != 5 { - t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 5, len(r.buf)) - } -} - -func TestChunkReaderDoesNotReuseBuf(t *testing.T) { - server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Options{MinBufLen: 4}) - if err != nil { - t.Fatal(err) - } - - src := []byte{1, 2, 3, 4, 5, 6, 7, 8} - server.Write(src) - - n1, err := r.Next(4) - if err != nil { - t.Fatal(err) - } - if bytes.Compare(n1, src[0:4]) != 0 { - t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:4], n1) - } - - n2, err := r.Next(4) - if err != nil { - t.Fatal(err) - } - if bytes.Compare(n2, src[4:8]) != 0 { - t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2) - } - - if bytes.Compare(n1, src[0:4]) != 0 { - t.Fatalf("Expected KeepLast to prevent Next from overwriting buf, expected %v but it was %v", src[0:4], n1) - } -} diff --git a/conn.go b/conn.go index f5cd9d64..e7d09828 100644 --- a/conn.go +++ b/conn.go @@ -12,8 +12,8 @@ import ( "github.com/pkg/errors" - "github.com/jackc/pgx/pgconn" - "github.com/jackc/pgx/pgproto3" + "github.com/jackc/pgconn" + "github.com/jackc/pgproto3" "github.com/jackc/pgx/pgtype" ) diff --git a/conn_pool.go b/conn_pool.go index fc7457ee..471a505c 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -7,7 +7,7 @@ import ( "github.com/pkg/errors" - "github.com/jackc/pgx/pgconn" + "github.com/jackc/pgconn" "github.com/jackc/pgx/pgtype" ) diff --git a/conn_pool_test.go b/conn_pool_test.go index f20c6010..4d1f2aaf 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -11,8 +11,8 @@ import ( "github.com/pkg/errors" + "github.com/jackc/pgconn" "github.com/jackc/pgx" - "github.com/jackc/pgx/pgconn" ) func createConnPool(t *testing.T, maxConnections int) *pgx.ConnPool { diff --git a/conn_test.go b/conn_test.go index 0df63bca..40074456 100644 --- a/conn_test.go +++ b/conn_test.go @@ -10,8 +10,8 @@ import ( "testing" "time" + "github.com/jackc/pgconn" "github.com/jackc/pgx" - "github.com/jackc/pgx/pgconn" "github.com/jackc/pgx/pgtype" "github.com/stretchr/testify/require" ) diff --git a/copy_from.go b/copy_from.go index 9116f3a0..34a28dff 100644 --- a/copy_from.go +++ b/copy_from.go @@ -6,7 +6,7 @@ import ( "fmt" "io" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/copy_from_test.go b/copy_from_test.go index 891da2d6..8809501f 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -6,8 +6,8 @@ import ( "testing" "time" + "github.com/jackc/pgconn" "github.com/jackc/pgx" - "github.com/jackc/pgx/pgconn" "github.com/pkg/errors" ) diff --git a/fastpath.go b/fastpath.go index a6a4de8a..6ac81b2c 100644 --- a/fastpath.go +++ b/fastpath.go @@ -3,8 +3,8 @@ package pgx import ( "encoding/binary" - "github.com/jackc/pgx/pgio" - "github.com/jackc/pgx/pgproto3" + "github.com/jackc/pgio" + "github.com/jackc/pgproto3" "github.com/jackc/pgx/pgtype" ) diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..d8fde0f5 --- /dev/null +++ b/go.mod @@ -0,0 +1,14 @@ +module github.com/jackc/pgx + +go 1.12 + +require ( + github.com/cockroachdb/apd v1.1.0 + github.com/jackc/pgconn v0.0.0-20190330221323-ed7d91dc9873 + github.com/jackc/pgio v1.0.0 + github.com/jackc/pgproto3 v1.0.0 + github.com/pkg/errors v0.8.1 + github.com/satori/go.uuid v1.2.0 + github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 + github.com/stretchr/testify v1.3.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..11affd8e --- /dev/null +++ b/go.sum @@ -0,0 +1,25 @@ +github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= +github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= +github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/pgconn v0.0.0-20190330221323-ed7d91dc9873 h1:M68R77AKFS7dub7R7WgJ9D6yiNuExOYhBuGtazlbr10= +github.com/jackc/pgconn v0.0.0-20190330221323-ed7d91dc9873/go.mod h1:8Bzf8vzi/ZpcgLgrq8IUHjZX4ZU+Hf6N6/AJ85+fDeE= +github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v1.0.0 h1:25tUmlES7eyD96oYaUHc1dLOFbgcJtFzCdnOOoqmA1I= +github.com/jackc/pgproto3 v1.0.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 h1:pntxY8Ary0t43dCZ5dqY4YTJCObLY1kIXl0uzMv+7DE= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= diff --git a/helper_test.go b/helper_test.go index b181ef31..1cbfd9f6 100644 --- a/helper_test.go +++ b/helper_test.go @@ -4,8 +4,8 @@ import ( "context" "testing" + "github.com/jackc/pgconn" "github.com/jackc/pgx" - "github.com/jackc/pgx/pgconn" "github.com/stretchr/testify/require" ) diff --git a/large_objects_test.go b/large_objects_test.go index e6279822..a1842918 100644 --- a/large_objects_test.go +++ b/large_objects_test.go @@ -6,8 +6,8 @@ import ( "os" "testing" + "github.com/jackc/pgconn" "github.com/jackc/pgx" - "github.com/jackc/pgx/pgconn" ) func TestLargeObjects(t *testing.T) { diff --git a/messages.go b/messages.go index cd504bfc..e6496373 100644 --- a/messages.go +++ b/messages.go @@ -6,7 +6,7 @@ import ( "reflect" "time" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/jackc/pgx/pgtype" ) diff --git a/pgconn/benchmark_test.go b/pgconn/benchmark_test.go deleted file mode 100644 index d2576324..00000000 --- a/pgconn/benchmark_test.go +++ /dev/null @@ -1,101 +0,0 @@ -package pgconn_test - -import ( - "context" - "os" - "testing" - - "github.com/jackc/pgx/pgconn" - "github.com/stretchr/testify/require" -) - -func BenchmarkConnect(b *testing.B) { - benchmarks := []struct { - name string - env string - }{ - {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, - {"TCP", "PGX_TEST_TCP_CONN_STRING"}, - } - - for _, bm := range benchmarks { - b.Run(bm.name, func(b *testing.B) { - connString := os.Getenv(bm.env) - if connString == "" { - b.Skipf("Skipping due to missing environment variable %v", bm.env) - } - - for i := 0; i < b.N; i++ { - conn, err := pgconn.Connect(context.Background(), connString) - require.Nil(b, err) - - err = conn.Close(context.Background()) - require.Nil(b, err) - } - }) - } -} - -func BenchmarkExec(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(b, err) - defer closeConn(b, conn) - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - _, err := conn.Exec(context.Background(), "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date").ReadAll() - require.Nil(b, err) - } -} - -func BenchmarkExecPossibleToCancel(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(b, err) - defer closeConn(b, conn) - - b.ResetTimer() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - for i := 0; i < b.N; i++ { - _, err := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date").ReadAll() - require.Nil(b, err) - } -} - -func BenchmarkExecPrepared(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(b, err) - defer closeConn(b, conn) - - _, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) - require.Nil(b, err) - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - result := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil).Read() - require.Nil(b, result.Err) - } -} - -func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(b, err) - defer closeConn(b, conn) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - _, err = conn.Prepare(ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) - require.Nil(b, err) - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - result := conn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() - require.Nil(b, result.Err) - } -} diff --git a/pgconn/config.go b/pgconn/config.go deleted file mode 100644 index fec1fedf..00000000 --- a/pgconn/config.go +++ /dev/null @@ -1,501 +0,0 @@ -package pgconn - -import ( - "context" - "crypto/tls" - "crypto/x509" - "fmt" - "io/ioutil" - "math" - "net" - "net/url" - "os" - "os/user" - "path/filepath" - "regexp" - "strconv" - "strings" - "time" - - "github.com/jackc/pgx/pgpassfile" - "github.com/pkg/errors" -) - -type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error - -// Config is the settings used to establish a connection to a PostgreSQL server. -type Config struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) - Port uint16 - Database string - User string - Password string - TLSConfig *tls.Config // nil disables TLS - DialFunc DialFunc // e.g. net.Dialer.DialContext - RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) - - Fallbacks []*FallbackConfig - - // AfterConnectFunc is called after successful connection. It can be used to set up the connection or to validate that - // server is acceptable. If this returns an error the connection is closed and the next fallback config is tried. This - // allows implementing high availability behavior such as libpq does with target_session_attrs. - AfterConnectFunc AfterConnectFunc - - // OnContextCancel is a callback function used to override cancellation behavior. It is called when a context.Context - // is canceled. Default cancellation behavior is to establish another connection to the PostgreSQL server and send a - // query cancel request. Some non-PostgreSQL servers (e.g. CockroachDB) that speak a subset of the PostgreSQL wire - // protocol do not support this cancellation method. - // - // It is called from a background goroutine. When the cancellation process has finished ContextCancel.Finish must be - // called whether it was successful or not. If an error occurs the connection should be closed. The connection must be - // in a ready for query state or be closed when ContextCancel.Finish is called. Use PgConn.ReceiveMessage() to read - // the connection until a ready for query message is received. - OnContextCancel func(*ContextCancel) - - // OnNotice is a callback function called when a notice response is received. - OnNotice NoticeHandler - - // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. - OnNotification NotificationHandler -} - -// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a -// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections. -type FallbackConfig struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) - Port uint16 - TLSConfig *tls.Config // nil disables TLS -} - -// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with -// net.Dial. -func NetworkAddress(host string, port uint16) (network, address string) { - if strings.HasPrefix(host, "/") { - network = "unix" - address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10) - } else { - network = "tcp" - address = fmt.Sprintf("%s:%d", host, port) - } - return network, address -} - -// ParseConfig builds a []*Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same -// defaults as libpq (e.g. port=5432) and understands most PG* environment variables. connString may be a URL or a DSN. -// It also may be empty to only read from the environment. If a password is not supplied it will attempt to read the -// .pgpass file. -// -// # Example DSN -// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca -// -// # Example URL -// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca -// -// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated -// values that will be tried in order. This can be used as part of a high availability system. See -// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. -// -// # Example URL -// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb -// -// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed -// via database URL or DSN: -// -// PGHOST -// PGPORT -// PGDATABASE -// PGUSER -// PGPASSWORD -// PGPASSFILE -// PGSSLMODE -// PGSSLCERT -// PGSSLKEY -// PGSSLROOTCERT -// PGAPPNAME -// PGCONNECT_TIMEOUT -// PGTARGETSESSIONATTRS -// -// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables. -// -// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are -// usually but not always the environment variable name downcased and without the "PG" prefix. -// -// Important TLS Security Notes: -// -// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if -// not set. -// -// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of -// security each sslmode provides. -// -// "verify-ca" mode currently is treated as "verify-full". e.g. It has stronger -// security guarantees than it would with libpq. Do not rely on this behavior as it -// may be possible to match libpq in the future. If you need full security use -// "verify-full". -func ParseConfig(connString string) (*Config, error) { - settings := defaultSettings() - addEnvSettings(settings) - - if connString != "" { - // connString may be a database URL or a DSN - if strings.HasPrefix(connString, "postgres://") { - err := addURLSettings(settings, connString) - if err != nil { - return nil, err - } - } else { - err := addDSNSettings(settings, connString) - if err != nil { - return nil, err - } - } - } - - config := &Config{ - Database: settings["database"], - User: settings["user"], - Password: settings["password"], - RuntimeParams: make(map[string]string), - } - - if connectTimeout, present := settings["connect_timeout"]; present { - dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout) - if err != nil { - return nil, err - } - config.DialFunc = dialFunc - } else { - defaultDialer := makeDefaultDialer() - config.DialFunc = defaultDialer.DialContext - } - - notRuntimeParams := map[string]struct{}{ - "host": struct{}{}, - "port": struct{}{}, - "database": struct{}{}, - "user": struct{}{}, - "password": struct{}{}, - "passfile": struct{}{}, - "connect_timeout": struct{}{}, - "sslmode": struct{}{}, - "sslkey": struct{}{}, - "sslcert": struct{}{}, - "sslrootcert": struct{}{}, - "target_session_attrs": struct{}{}, - } - - for k, v := range settings { - if _, present := notRuntimeParams[k]; present { - continue - } - config.RuntimeParams[k] = v - } - - fallbacks := []*FallbackConfig{} - - hosts := strings.Split(settings["host"], ",") - ports := strings.Split(settings["port"], ",") - - for i, host := range hosts { - var portStr string - if i < len(ports) { - portStr = ports[i] - } else { - portStr = ports[0] - } - - port, err := parsePort(portStr) - if err != nil { - return nil, fmt.Errorf("invalid port: %v", settings["port"]) - } - - var tlsConfigs []*tls.Config - - // Ignore TLS settings if Unix domain socket like libpq - if network, _ := NetworkAddress(host, port); network == "unix" { - tlsConfigs = append(tlsConfigs, nil) - } else { - var err error - tlsConfigs, err = configTLS(settings) - if err != nil { - return nil, err - } - } - - for _, tlsConfig := range tlsConfigs { - fallbacks = append(fallbacks, &FallbackConfig{ - Host: host, - Port: port, - TLSConfig: tlsConfig, - }) - } - } - - config.Host = fallbacks[0].Host - config.Port = fallbacks[0].Port - config.TLSConfig = fallbacks[0].TLSConfig - config.Fallbacks = fallbacks[1:] - - passfile, err := pgpassfile.ReadPassfile(settings["passfile"]) - if err == nil { - if config.Password == "" { - host := config.Host - if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" { - host = "localhost" - } - - config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User) - } - } - - if settings["target_session_attrs"] == "read-write" { - config.AfterConnectFunc = AfterConnectTargetSessionAttrsReadWrite - } else if settings["target_session_attrs"] != "any" { - return nil, fmt.Errorf("unknown target_session_attrs value %v", settings["target_session_attrs"]) - } - - return config, nil -} - -func defaultSettings() map[string]string { - settings := make(map[string]string) - - settings["host"] = defaultHost() - settings["port"] = "5432" - - // Default to the OS user name. Purposely ignoring err getting user name from - // OS. The client application will simply have to specify the user in that - // case (which they typically will be doing anyway). - user, err := user.Current() - if err == nil { - settings["user"] = user.Username - settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") - } - - settings["target_session_attrs"] = "any" - - return settings -} - -// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost -// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it -// checks the existence of common locations. -func defaultHost() string { - candidatePaths := []string{ - "/var/run/postgresql", // Debian - "/private/tmp", // OSX - homebrew - "/tmp", // standard PostgreSQL - } - - for _, path := range candidatePaths { - if _, err := os.Stat(path); err == nil { - return path - } - } - - return "localhost" -} - -func addEnvSettings(settings map[string]string) { - nameMap := map[string]string{ - "PGHOST": "host", - "PGPORT": "port", - "PGDATABASE": "database", - "PGUSER": "user", - "PGPASSWORD": "password", - "PGPASSFILE": "passfile", - "PGAPPNAME": "application_name", - "PGCONNECT_TIMEOUT": "connect_timeout", - "PGSSLMODE": "sslmode", - "PGSSLKEY": "sslkey", - "PGSSLCERT": "sslcert", - "PGSSLROOTCERT": "sslrootcert", - "PGTARGETSESSIONATTRS": "target_session_attrs", - } - - for envname, realname := range nameMap { - value := os.Getenv(envname) - if value != "" { - settings[realname] = value - } - } -} - -func addURLSettings(settings map[string]string, connString string) error { - url, err := url.Parse(connString) - if err != nil { - return err - } - - if url.User != nil { - settings["user"] = url.User.Username() - if password, present := url.User.Password(); present { - settings["password"] = password - } - } - - // Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port. - var hosts []string - var ports []string - for _, host := range strings.Split(url.Host, ",") { - parts := strings.SplitN(host, ":", 2) - if parts[0] != "" { - hosts = append(hosts, parts[0]) - } - if len(parts) == 2 { - ports = append(ports, parts[1]) - } - } - if len(hosts) > 0 { - settings["host"] = strings.Join(hosts, ",") - } - if len(ports) > 0 { - settings["port"] = strings.Join(ports, ",") - } - - database := strings.TrimLeft(url.Path, "/") - if database != "" { - settings["database"] = database - } - - for k, v := range url.Query() { - settings[k] = v[0] - } - - return nil -} - -var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) - -func addDSNSettings(settings map[string]string, s string) error { - m := dsnRegexp.FindAllStringSubmatch(s, -1) - - for _, b := range m { - settings[b[1]] = b[2] - } - - return nil -} - -type pgTLSArgs struct { - sslMode string - sslRootCert string - sslCert string - sslKey string -} - -// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is -// necessary to allow returning multiple TLS configs as sslmode "allow" and -// "prefer" allow fallback. -func configTLS(settings map[string]string) ([]*tls.Config, error) { - host := settings["host"] - sslmode := settings["sslmode"] - sslrootcert := settings["sslrootcert"] - sslcert := settings["sslcert"] - sslkey := settings["sslkey"] - - // Match libpq default behavior - if sslmode == "" { - sslmode = "prefer" - } - - tlsConfig := &tls.Config{} - - switch sslmode { - case "disable": - return []*tls.Config{nil}, nil - case "allow", "prefer": - tlsConfig.InsecureSkipVerify = true - case "require": - tlsConfig.InsecureSkipVerify = sslrootcert == "" - case "verify-ca", "verify-full": - tlsConfig.ServerName = host - default: - return nil, errors.New("sslmode is invalid") - } - - if sslrootcert != "" { - caCertPool := x509.NewCertPool() - - caPath := sslrootcert - caCert, err := ioutil.ReadFile(caPath) - if err != nil { - return nil, errors.Wrapf(err, "unable to read CA file %q", caPath) - } - - if !caCertPool.AppendCertsFromPEM(caCert) { - return nil, errors.Wrap(err, "unable to add CA to cert pool") - } - - tlsConfig.RootCAs = caCertPool - tlsConfig.ClientCAs = caCertPool - } - - if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { - return nil, fmt.Errorf(`both "sslcert" and "sslkey" are required`) - } - - if sslcert != "" && sslkey != "" { - cert, err := tls.LoadX509KeyPair(sslcert, sslkey) - if err != nil { - return nil, errors.Wrap(err, "unable to read cert") - } - - tlsConfig.Certificates = []tls.Certificate{cert} - } - - switch sslmode { - case "allow": - return []*tls.Config{nil, tlsConfig}, nil - case "prefer": - return []*tls.Config{tlsConfig, nil}, nil - case "require", "verify-ca", "verify-full": - return []*tls.Config{tlsConfig}, nil - default: - panic("BUG: bad sslmode should already have been caught") - } -} - -func parsePort(s string) (uint16, error) { - port, err := strconv.ParseUint(s, 10, 16) - if err != nil { - return 0, err - } - if port < 1 || port > math.MaxUint16 { - return 0, errors.New("outside range") - } - return uint16(port), nil -} - -func makeDefaultDialer() *net.Dialer { - return &net.Dialer{KeepAlive: 5 * time.Minute} -} - -func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { - timeout, err := strconv.ParseInt(s, 10, 64) - if err != nil { - return nil, err - } - if timeout < 0 { - return nil, errors.New("negative timeout") - } - - d := makeDefaultDialer() - d.Timeout = time.Duration(timeout) * time.Second - return d.DialContext, nil -} - -// AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible -// target_session_attrs=read-write. -func AfterConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { - result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() - if result.Err != nil { - return result.Err - } - - if string(result.Rows[0][0]) == "on" { - return errors.New("read only connection") - } - - return nil -} diff --git a/pgconn/config_test.go b/pgconn/config_test.go deleted file mode 100644 index c7b65861..00000000 --- a/pgconn/config_test.go +++ /dev/null @@ -1,562 +0,0 @@ -package pgconn_test - -import ( - "crypto/tls" - "fmt" - "io/ioutil" - "os" - "os/user" - "testing" - - "github.com/jackc/pgx/pgconn" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestParseConfig(t *testing.T) { - t.Parallel() - - var osUserName string - osUser, err := user.Current() - if err == nil { - osUserName = osUser.Username - } - - tests := []struct { - name string - connString string - config *pgconn.Config - }{ - // Test all sslmodes - { - name: "sslmode not set (prefer)", - connString: "postgres://jack:secret@localhost:5432/mydb", - config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - RuntimeParams: map[string]string{}, - Fallbacks: []*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: "localhost", - Port: 5432, - TLSConfig: nil, - }, - }, - }, - }, - { - name: "sslmode disable", - connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", - config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "sslmode allow", - connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=allow", - config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, - Fallbacks: []*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: "localhost", - Port: 5432, - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - }, - }, - }, - }, - { - name: "sslmode prefer", - connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer", - config: &pgconn.Config{ - - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - RuntimeParams: map[string]string{}, - Fallbacks: []*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: "localhost", - Port: 5432, - TLSConfig: nil, - }, - }, - }, - }, - { - name: "sslmode require", - connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=require", - config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "sslmode verify-ca", - connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca", - config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ServerName: "localhost"}, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "sslmode verify-full", - connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-full", - config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ServerName: "localhost"}, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "database url everything", - connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema", - config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{ - "application_name": "pgxtest", - "search_path": "myschema", - }, - }, - }, - { - name: "database url missing password", - connString: "postgres://jack@localhost:5432/mydb?sslmode=disable", - config: &pgconn.Config{ - User: "jack", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "database url missing user and password", - connString: "postgres://localhost:5432/mydb?sslmode=disable", - config: &pgconn.Config{ - User: osUserName, - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "database url missing port", - connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", - config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "database url unix domain socket host", - connString: "postgres:///foo?host=/tmp", - config: &pgconn.Config{ - User: osUserName, - Host: "/tmp", - Port: 5432, - Database: "foo", - TLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "DSN everything", - connString: "user=jack password=secret host=localhost port=5432 database=mydb sslmode=disable application_name=pgxtest search_path=myschema", - config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{ - "application_name": "pgxtest", - "search_path": "myschema", - }, - }, - }, - { - name: "URL multiple hosts", - connString: "postgres://jack:secret@foo,bar,baz/mydb?sslmode=disable", - config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "foo", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, - Fallbacks: []*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: "bar", - Port: 5432, - TLSConfig: nil, - }, - &pgconn.FallbackConfig{ - Host: "baz", - Port: 5432, - TLSConfig: nil, - }, - }, - }, - }, - { - name: "URL multiple hosts and ports", - connString: "postgres://jack:secret@foo:1,bar:2,baz:3/mydb?sslmode=disable", - config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "foo", - Port: 1, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, - Fallbacks: []*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: "bar", - Port: 2, - TLSConfig: nil, - }, - &pgconn.FallbackConfig{ - Host: "baz", - Port: 3, - TLSConfig: nil, - }, - }, - }, - }, - { - name: "DSN multiple hosts one port", - connString: "user=jack password=secret host=foo,bar,baz port=5432 database=mydb sslmode=disable", - config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "foo", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, - Fallbacks: []*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: "bar", - Port: 5432, - TLSConfig: nil, - }, - &pgconn.FallbackConfig{ - Host: "baz", - Port: 5432, - TLSConfig: nil, - }, - }, - }, - }, - { - name: "DSN multiple hosts multiple ports", - connString: "user=jack password=secret host=foo,bar,baz port=1,2,3 database=mydb sslmode=disable", - config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "foo", - Port: 1, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, - Fallbacks: []*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: "bar", - Port: 2, - TLSConfig: nil, - }, - &pgconn.FallbackConfig{ - Host: "baz", - Port: 3, - TLSConfig: nil, - }, - }, - }, - }, - { - name: "multiple hosts and fallback tsl", - connString: "user=jack password=secret host=foo,bar,baz database=mydb sslmode=prefer", - config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "foo", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - RuntimeParams: map[string]string{}, - Fallbacks: []*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: "foo", - Port: 5432, - TLSConfig: nil, - }, - &pgconn.FallbackConfig{ - Host: "bar", - Port: 5432, - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }}, - &pgconn.FallbackConfig{ - Host: "bar", - Port: 5432, - TLSConfig: nil, - }, - &pgconn.FallbackConfig{ - Host: "baz", - Port: 5432, - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }}, - &pgconn.FallbackConfig{ - Host: "baz", - Port: 5432, - TLSConfig: nil, - }, - }, - }, - }, - { - name: "target_session_attrs", - connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write", - config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, - AfterConnectFunc: pgconn.AfterConnectTargetSessionAttrsReadWrite, - }, - }, - } - - for i, tt := range tests { - config, err := pgconn.ParseConfig(tt.connString) - if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) { - continue - } - - assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) - } -} - -func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) { - if !assert.NotNil(t, expected) { - return - } - if !assert.NotNil(t, actual) { - return - } - - assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) - assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) - assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) - assert.Equalf(t, expected.User, actual.User, "%s - User", testName) - assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) - assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) - - // Can't test function equality, so just test that they are set or not. - assert.Equalf(t, expected.AfterConnectFunc == nil, actual.AfterConnectFunc == nil, "%s - AfterConnectFunc", testName) - - if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { - if expected.TLSConfig != nil { - assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName) - assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName) - } - } - - if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) { - for i := range expected.Fallbacks { - assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i) - assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i) - - if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) { - if expected.Fallbacks[i].TLSConfig != nil { - assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName) - assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName) - } - } - } - } -} - -func TestParseConfigEnvLibpq(t *testing.T) { - var osUserName string - osUser, err := user.Current() - if err == nil { - osUserName = osUser.Username - } - - pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"} - - savedEnv := make(map[string]string) - for _, n := range pgEnvvars { - savedEnv[n] = os.Getenv(n) - } - defer func() { - for k, v := range savedEnv { - err := os.Setenv(k, v) - if err != nil { - t.Fatalf("Unable to restore environment: %v", err) - } - } - }() - - tests := []struct { - name string - envvars map[string]string - config *pgconn.Config - }{ - { - // not testing no environment at all as that would use default host and that can vary. - name: "PGHOST only", - envvars: map[string]string{"PGHOST": "123.123.123.123"}, - config: &pgconn.Config{ - User: osUserName, - Host: "123.123.123.123", - Port: 5432, - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - RuntimeParams: map[string]string{}, - Fallbacks: []*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: "123.123.123.123", - Port: 5432, - TLSConfig: nil, - }, - }, - }, - }, - { - name: "All non-TLS environment", - envvars: map[string]string{ - "PGHOST": "123.123.123.123", - "PGPORT": "7777", - "PGDATABASE": "foo", - "PGUSER": "bar", - "PGPASSWORD": "baz", - "PGCONNECT_TIMEOUT": "10", - "PGSSLMODE": "disable", - "PGAPPNAME": "pgxtest", - }, - config: &pgconn.Config{ - Host: "123.123.123.123", - Port: 7777, - Database: "foo", - User: "bar", - Password: "baz", - TLSConfig: nil, - RuntimeParams: map[string]string{"application_name": "pgxtest"}, - }, - }, - } - - for i, tt := range tests { - for _, n := range pgEnvvars { - err := os.Unsetenv(n) - require.NoError(t, err) - } - - for k, v := range tt.envvars { - err := os.Setenv(k, v) - require.NoError(t, err) - } - - config, err := pgconn.ParseConfig("") - if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) { - continue - } - - assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) - } -} - -func TestParseConfigReadsPgPassfile(t *testing.T) { - t.Parallel() - - tf, err := ioutil.TempFile("", "") - require.NoError(t, err) - - defer tf.Close() - defer os.Remove(tf.Name()) - - _, err = tf.Write([]byte("test1:5432:curlydb:curly:nyuknyuknyuk")) - require.NoError(t, err) - - connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tf.Name()) - expected := &pgconn.Config{ - User: "curly", - Password: "nyuknyuknyuk", - Host: "test1", - Port: 5432, - Database: "curlydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, - } - - actual, err := pgconn.ParseConfig(connString) - assert.NoError(t, err) - - assertConfigsEqual(t, expected, actual, "passfile") -} diff --git a/pgconn/doc.go b/pgconn/doc.go deleted file mode 100644 index 89e47536..00000000 --- a/pgconn/doc.go +++ /dev/null @@ -1,29 +0,0 @@ -// Package pgconn is a low-level PostgreSQL database driver. -/* -pgconn provides lower level access to a PostgreSQL connection than a database/sql or pgx connection. It operates at -nearly the same level is the C library libpq. - -Establishing a Connection - -Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for -libpq style environment variables. - -Executing a Query - -ExecParams and ExecPrepared execute a single query. They return readers that iterate over each row. The Read method -reads all rows into memory. - -Executing Multiple Queries in a Single Round Trip - -Exec and ExecBatch can execute multiple queries in a single round trip. The return readers that iterate over each query -result. The ReadAll method reads all query results into memory. - -Context Support - -All potentially blocking operations take a context.Context. If a context is canceled while a query is in progress the -method immediately returns. In the background a cancel request will be sent to the PostgreSQL server. If the -cancellation fails or hangs for more than a short time (approximately 15 seconds) the connection will be closed. It is -safe to use the connection while this background cancellation is in progress. Any calls will block until the -cancellation and resynchronization is complete (and those calls can be aborted by a context cancellation). -*/ -package pgconn diff --git a/pgconn/helper_test.go b/pgconn/helper_test.go deleted file mode 100644 index c5ac6e01..00000000 --- a/pgconn/helper_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package pgconn_test - -import ( - "context" - "testing" - "time" - - "github.com/jackc/pgx/pgconn" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func closeConn(t testing.TB, conn *pgconn.PgConn) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - require.Nil(t, conn.Close(ctx)) -} - -// Do a simple query to ensure the connection is still usable -func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).Read() - cancel() - - require.Nil(t, result.Err) - assert.Equal(t, 3, len(result.Rows)) - assert.Equal(t, "1", string(result.Rows[0][0])) - assert.Equal(t, "2", string(result.Rows[1][0])) - assert.Equal(t, "3", string(result.Rows[2][0])) -} diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go deleted file mode 100644 index c785f367..00000000 --- a/pgconn/pgconn.go +++ /dev/null @@ -1,1407 +0,0 @@ -package pgconn - -import ( - "context" - "crypto/md5" - "crypto/tls" - "encoding/binary" - "encoding/hex" - "errors" - "fmt" - "io" - "net" - "strconv" - "strings" - "sync" - "time" - - "github.com/jackc/pgx/pgio" - "github.com/jackc/pgx/pgproto3" -) - -var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) - -// PgError represents an error reported by the PostgreSQL server. See -// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for -// detailed field description. -type PgError struct { - Severity string - Code string - Message string - Detail string - Hint string - Position int32 - InternalPosition int32 - InternalQuery string - Where string - SchemaName string - TableName string - ColumnName string - DataTypeName string - ConstraintName string - File string - Line int32 - Routine string -} - -func (pe *PgError) Error() string { - return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" -} - -// Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from -// LISTEN/NOTIFY notification. -type Notice PgError - -// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system -type Notification struct { - PID uint32 // backend pid that sent the notification - Channel string // channel from which notification was received - Payload string -} - -// DialFunc is a function that can be used to connect to a PostgreSQL server -type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) - -// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at -// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin -// of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY -// notification. -type NoticeHandler func(*PgConn, *Notice) - -// NotificationHandler is a function that can handle notifications received from the PostgreSQL server. Notifications -// can be received at any time, usually during handling of a query response. The *PgConn is provided so the handler is -// aware of the origin of the notice, but it must not invoke any query method. Be aware that this is distinct from a -// notice event. -type NotificationHandler func(*PgConn, *Notification) - -// ErrTLSRefused occurs when the connection attempt requires TLS and the -// PostgreSQL server refuses to use TLS -var ErrTLSRefused = errors.New("server refused TLS connection") - -// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. -type PgConn struct { - conn net.Conn // the underlying TCP or unix domain socket connection - pid uint32 // backend pid - secretKey uint32 // key to use to send a cancel query message to the server - parameterStatuses map[string]string // parameters that have been reported by the server - TxStatus byte - Frontend *pgproto3.Frontend - - Config *Config - - controller chan interface{} - - closed bool - - bufferingReceive bool - bufferingReceiveMux sync.Mutex - bufferingReceiveMsg pgproto3.BackendMessage - bufferingReceiveErr error -} - -// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) -// to provide configuration. See documention for ParseConfig for details. ctx can be used to cancel a connect attempt. -func Connect(ctx context.Context, connString string) (*PgConn, error) { - config, err := ParseConfig(connString) - if err != nil { - return nil, err - } - - return ConnectConfig(ctx, config) -} - -// Connect establishes a connection to a PostgreSQL server using config. ctx can be used to cancel a connect attempt. -// -// If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An -// authentication error will terminate the chain of attempts (like libpq: -// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise, -// if all attempts fail the last error is returned. -func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err error) { - // For convenience set a few defaults if not already set. This makes it simpler to directly construct a config. - if config.Port == 0 { - config.Port = 5432 - } - if config.DialFunc == nil { - config.DialFunc = makeDefaultDialer().DialContext - } - if config.RuntimeParams == nil { - config.RuntimeParams = make(map[string]string) - } - - // Simplify usage by treating primary config and fallbacks the same. - fallbackConfigs := []*FallbackConfig{ - { - Host: config.Host, - Port: config.Port, - TLSConfig: config.TLSConfig, - }, - } - fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) - - for _, fc := range fallbackConfigs { - pgConn, err = connect(ctx, config, fc) - if err == nil { - return pgConn, nil - } else if err, ok := err.(*PgError); ok { - return nil, err - } - } - - return nil, err -} - -func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { - pgConn := new(PgConn) - pgConn.Config = config - pgConn.controller = make(chan interface{}, 1) - - var err error - network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) - pgConn.conn, err = config.DialFunc(ctx, network, address) - if err != nil { - return nil, err - } - - pgConn.parameterStatuses = make(map[string]string) - - if config.TLSConfig != nil { - if err := pgConn.startTLS(config.TLSConfig); err != nil { - pgConn.conn.Close() - return nil, err - } - } - - pgConn.Frontend, err = pgproto3.NewFrontend(pgConn.conn, pgConn.conn) - if err != nil { - return nil, err - } - - startupMsg := pgproto3.StartupMessage{ - ProtocolVersion: pgproto3.ProtocolVersionNumber, - Parameters: make(map[string]string), - } - - // Copy default run-time params - for k, v := range config.RuntimeParams { - startupMsg.Parameters[k] = v - } - - startupMsg.Parameters["user"] = config.User - if config.Database != "" { - startupMsg.Parameters["database"] = config.Database - } - - if _, err := pgConn.conn.Write(startupMsg.Encode(nil)); err != nil { - pgConn.conn.Close() - return nil, err - } - - for { - msg, err := pgConn.ReceiveMessage() - if err != nil { - return nil, err - } - - switch msg := msg.(type) { - case *pgproto3.BackendKeyData: - pgConn.pid = msg.ProcessID - pgConn.secretKey = msg.SecretKey - case *pgproto3.Authentication: - if err = pgConn.rxAuthenticationX(msg); err != nil { - pgConn.conn.Close() - return nil, err - } - case *pgproto3.ReadyForQuery: - if config.AfterConnectFunc != nil { - err := config.AfterConnectFunc(ctx, pgConn) - if err != nil { - pgConn.conn.Close() - return nil, fmt.Errorf("AfterConnectFunc: %v", err) - } - } - return pgConn, nil - case *pgproto3.ParameterStatus: - // handled by ReceiveMessage - case *pgproto3.ErrorResponse: - pgConn.conn.Close() - return nil, errorResponseToPgError(msg) - default: - pgConn.conn.Close() - return nil, errors.New("unexpected message") - } - } -} - -func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { - err = binary.Write(pgConn.conn, binary.BigEndian, []int32{8, 80877103}) - if err != nil { - return - } - - response := make([]byte, 1) - if _, err = io.ReadFull(pgConn.conn, response); err != nil { - return - } - - if response[0] != 'S' { - return ErrTLSRefused - } - - pgConn.conn = tls.Client(pgConn.conn, tlsConfig) - - return nil -} - -func (c *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { - switch msg.Type { - case pgproto3.AuthTypeOk: - case pgproto3.AuthTypeCleartextPassword: - err = c.txPasswordMessage(c.Config.Password) - case pgproto3.AuthTypeMD5Password: - digestedPassword := "md5" + hexMD5(hexMD5(c.Config.Password+c.Config.User)+string(msg.Salt[:])) - err = c.txPasswordMessage(digestedPassword) - default: - err = errors.New("Received unknown authentication message") - } - - return -} - -func (pgConn *PgConn) txPasswordMessage(password string) (err error) { - msg := &pgproto3.PasswordMessage{Password: password} - _, err = pgConn.conn.Write(msg.Encode(nil)) - return err -} - -func hexMD5(s string) string { - hash := md5.New() - io.WriteString(hash, s) - return hex.EncodeToString(hash.Sum(nil)) -} - -func (pgConn *PgConn) signalMessage() chan struct{} { - if pgConn.bufferingReceive { - panic("BUG: signalMessage when already in progress") - } - - pgConn.bufferingReceive = true - pgConn.bufferingReceiveMux.Lock() - - ch := make(chan struct{}) - go func() { - pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.Frontend.Receive() - pgConn.bufferingReceiveMux.Unlock() - close(ch) - }() - - return ch -} - -func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { - var msg pgproto3.BackendMessage - var err error - if pgConn.bufferingReceive { - pgConn.bufferingReceiveMux.Lock() - msg = pgConn.bufferingReceiveMsg - err = pgConn.bufferingReceiveErr - pgConn.bufferingReceiveMux.Unlock() - pgConn.bufferingReceive = false - - // If a timeout error happened in the background try the read again. - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - msg, err = pgConn.Frontend.Receive() - } - } else { - msg, err = pgConn.Frontend.Receive() - } - - if err != nil { - // Close on anything other than timeout error - everything else is fatal - if err, ok := err.(net.Error); !(ok && err.Timeout()) { - pgConn.hardClose() - } - - return nil, err - } - - switch msg := msg.(type) { - case *pgproto3.ReadyForQuery: - pgConn.TxStatus = msg.TxStatus - case *pgproto3.ParameterStatus: - pgConn.parameterStatuses[msg.Name] = msg.Value - case *pgproto3.ErrorResponse: - if msg.Severity == "FATAL" { - pgConn.hardClose() - return nil, errorResponseToPgError(msg) - } - case *pgproto3.NoticeResponse: - if pgConn.Config.OnNotice != nil { - pgConn.Config.OnNotice(pgConn, noticeResponseToNotice(msg)) - } - case *pgproto3.NotificationResponse: - if pgConn.Config.OnNotification != nil { - pgConn.Config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload}) - } - } - - return msg, nil -} - -// Conn returns the underlying net.Conn. -func (pgConn *PgConn) Conn() net.Conn { - return pgConn.conn -} - -// PID returns the backend PID. -func (pgConn *PgConn) PID() uint32 { - return pgConn.pid -} - -// SecretKey returns the backend secret key used to send a cancel query message to the server. -func (pgConn *PgConn) SecretKey() uint32 { - return pgConn.secretKey -} - -// Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by -// sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The -// underlying net.Conn.Close() will always be called regardless of any other errors. -func (pgConn *PgConn) Close(ctx context.Context) error { - if pgConn.closed { - return nil - } - pgConn.closed = true - - defer pgConn.conn.Close() - - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContext() - - _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) - if err != nil { - return preferContextOverNetTimeoutError(ctx, err) - } - - _, err = pgConn.conn.Read(make([]byte, 1)) - if err != io.EOF { - return preferContextOverNetTimeoutError(ctx, err) - } - - return pgConn.conn.Close() -} - -// hardClose closes the underlying connection without sending the exit message. -func (pgConn *PgConn) hardClose() error { - if pgConn.closed { - return nil - } - pgConn.closed = true - return pgConn.conn.Close() -} - -// TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect deatch of -// underlying connection. -func (pgConn *PgConn) IsAlive() bool { - return !pgConn.closed -} - -// ParameterStatus returns the value of a parameter reported by the server (e.g. -// server_version). Returns an empty string for unknown parameters. -func (pgConn *PgConn) ParameterStatus(key string) string { - return pgConn.parameterStatuses[key] -} - -// CommandTag is the result of an Exec function -type CommandTag string - -// RowsAffected returns the number of rows affected. If the CommandTag was not -// for a row affecting command (e.g. "CREATE TABLE") then it returns 0. -func (ct CommandTag) RowsAffected() int64 { - s := string(ct) - index := strings.LastIndex(s, " ") - if index == -1 { - return 0 - } - n, _ := strconv.ParseInt(s[index+1:], 10, 64) - return n -} - -// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == -// true. Otherwise returns err. -func preferContextOverNetTimeoutError(ctx context.Context, err error) error { - if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { - return ctx.Err() - } - return err -} - -// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from -// ctx.Done(). The returned cleanup function must be called to terminate this goroutine. The cleanup function is safe to -// call multiple times. -func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func()) { - if ctx.Done() != nil { - deadlineWasSet := false - doneChan := make(chan struct{}) - go func() { - select { - case <-ctx.Done(): - conn.SetDeadline(deadlineTime) - deadlineWasSet = true - <-doneChan - case <-doneChan: - } - }() - - finished := false - return func() { - if !finished { - doneChan <- struct{}{} - if deadlineWasSet { - conn.SetDeadline(time.Time{}) - } - finished = true - } - } - } - - return func() {} -} - -type PreparedStatementDescription struct { - Name string - SQL string - ParamOIDs []uint32 - Fields []pgproto3.FieldDescription -} - -// Prepare creates a prepared statement. -func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case pgConn.controller <- pgConn: - } - cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContextDeadline() - - var buf []byte - buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) - buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) - buf = (&pgproto3.Sync{}).Encode(buf) - - _, err := pgConn.conn.Write(buf) - if err != nil { - pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) - } - - psd := &PreparedStatementDescription{Name: name, SQL: sql} - - var parseErr error - -readloop: - for { - msg, err := pgConn.ReceiveMessage() - if err != nil { - go pgConn.recoverFromTimeout() - return nil, preferContextOverNetTimeoutError(ctx, err) - } - - switch msg := msg.(type) { - case *pgproto3.ParameterDescription: - psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) - copy(psd.ParamOIDs, msg.ParameterOIDs) - case *pgproto3.RowDescription: - psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) - copy(psd.Fields, msg.Fields) - case *pgproto3.ErrorResponse: - parseErr = errorResponseToPgError(msg) - case *pgproto3.ReadyForQuery: - break readloop - } - } - - <-pgConn.controller - - if parseErr != nil { - return nil, parseErr - } - return psd, nil -} - -func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { - return &PgError{ - Severity: string(msg.Severity), - Code: string(msg.Code), - Message: string(msg.Message), - Detail: string(msg.Detail), - Hint: string(msg.Hint), - Position: msg.Position, - InternalPosition: msg.InternalPosition, - InternalQuery: string(msg.InternalQuery), - Where: string(msg.Where), - SchemaName: string(msg.SchemaName), - TableName: string(msg.TableName), - ColumnName: string(msg.ColumnName), - DataTypeName: string(msg.DataTypeName), - ConstraintName: string(msg.ConstraintName), - File: string(msg.File), - Line: msg.Line, - Routine: string(msg.Routine), - } -} - -func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { - pgerr := errorResponseToPgError((*pgproto3.ErrorResponse)(msg)) - return (*Notice)(pgerr) -} - -// cancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel -// request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there -// is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9 -func (pgConn *PgConn) cancelRequest(ctx context.Context) error { - // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing - // the connection config. This is important in high availability configurations where fallback connections may be - // specified or DNS may be used to load balance. - serverAddr := pgConn.conn.RemoteAddr() - cancelConn, err := pgConn.Config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) - if err != nil { - return err - } - defer cancelConn.Close() - - cleanupContext := contextDoneToConnDeadline(ctx, cancelConn) - defer cleanupContext() - - buf := make([]byte, 16) - binary.BigEndian.PutUint32(buf[0:4], 16) - binary.BigEndian.PutUint32(buf[4:8], 80877102) - binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid)) - binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) - _, err = cancelConn.Write(buf) - if err != nil { - return preferContextOverNetTimeoutError(ctx, err) - } - - _, err = cancelConn.Read(buf) - if err != io.EOF { - return fmt.Errorf("Server failed to close connection after cancel query request: %v", preferContextOverNetTimeoutError(ctx, err)) - } - - return nil -} - -// WaitUntilReady waits until a previous context cancellation has been completed and the connection is ready for use. -// This is done automatically by all methods that need the connection to be ready for use. The only expected use for -// this method is for a connection pool to wait for a returned connection to be usable again before making it available. -func (pgConn *PgConn) WaitUntilReady(ctx context.Context) error { - select { - case <-ctx.Done(): - return ctx.Err() - case pgConn.controller <- pgConn: - // The connection must be ready since it was locked. Immediately unlock it. - <-pgConn.controller - } - - return nil -} - -// WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not -// received. -func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { - select { - case <-ctx.Done(): - return ctx.Err() - case pgConn.controller <- pgConn: - } - cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContextDeadline() - defer func() { <-pgConn.controller }() - - for { - msg, err := pgConn.ReceiveMessage() - if err != nil { - return preferContextOverNetTimeoutError(ctx, err) - } - - switch msg.(type) { - case *pgproto3.NotificationResponse: - return nil - } - } -} - -// Exec executes SQL via the PostgreSQL simple query protocol. SQL may contain multiple queries. Execution is -// implicitly wrapped in a transaction unless a transaction is already in progress or SQL contains transaction control -// statements. -// -// Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries. -func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { - multiResult := &MultiResultReader{ - pgConn: pgConn, - ctx: ctx, - cleanupContextDeadline: func() {}, - } - - select { - case <-ctx.Done(): - multiResult.closed = true - multiResult.err = ctx.Err() - return multiResult - case pgConn.controller <- multiResult: - } - multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) - - var buf []byte - buf = (&pgproto3.Query{String: sql}).Encode(buf) - - _, err := pgConn.conn.Write(buf) - if err != nil { - pgConn.hardClose() - multiResult.cleanupContextDeadline() - multiResult.closed = true - multiResult.err = preferContextOverNetTimeoutError(ctx, err) - <-pgConn.controller - return multiResult - } - - return multiResult -} - -// ExecParams executes a command via the PostgreSQL extended query protocol. -// -// sql is a SQL command string. It may only contain one query. Parameter substitution is positional using $1, $2, $3, -// etc. -// -// paramValues are the parameter values. It must be encoded in the format given by paramFormats. -// -// paramOIDs is a slice of data type OIDs for paramValues. If paramOIDs is nil, the server will infer the data type for -// all parameters. Any paramOID element that is 0 that will cause the server to infer the data type for that parameter. -// ExecParams will panic if len(paramOIDs) is not 0, 1, or len(paramValues). -// -// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or -// binary format. If paramFormats is nil all results will be in text protocol. ExecParams will panic if -// len(paramFormats) is not 0, 1, or len(paramValues). -// -// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or -// binary format. If resultFormats is nil all results will be in text protocol. -// -// ResultReader must be closed before PgConn can be used again. -func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader { - result := &ResultReader{ - pgConn: pgConn, - ctx: ctx, - cleanupContextDeadline: func() {}, - } - - select { - case <-ctx.Done(): - result.concludeCommand("", ctx.Err()) - result.closed = true - return result - case pgConn.controller <- result: - } - result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) - - var buf []byte - - // TODO - refactor ExecParams and ExecPrepared - these lines only difference - buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) - buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) - - buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) - buf = (&pgproto3.Execute{}).Encode(buf) - buf = (&pgproto3.Sync{}).Encode(buf) - - _, err := pgConn.conn.Write(buf) - if err != nil { - pgConn.hardClose() - result.concludeCommand("", err) - result.cleanupContextDeadline() - result.closed = true - <-pgConn.controller - } - - return result -} - -// ExecPrepared enqueues the execution of a prepared statement via the PostgreSQL extended query protocol. -// -// paramValues are the parameter values. It must be encoded in the format given by paramFormats. -// -// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or -// binary format. If paramFormats is nil all results will be in text protocol. ExecPrepared will panic if -// len(paramFormats) is not 0, 1, or len(paramValues). -// -// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or -// binary format. If resultFormats is nil all results will be in text protocol. -// -// ResultReader must be closed before PgConn can be used again. -func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader { - result := &ResultReader{ - pgConn: pgConn, - ctx: ctx, - cleanupContextDeadline: func() {}, - } - - select { - case <-ctx.Done(): - result.concludeCommand("", ctx.Err()) - result.closed = true - return result - case pgConn.controller <- result: - } - result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) - - var buf []byte - buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) - buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) - buf = (&pgproto3.Execute{}).Encode(buf) - buf = (&pgproto3.Sync{}).Encode(buf) - - _, err := pgConn.conn.Write(buf) - if err != nil { - pgConn.hardClose() - result.concludeCommand("", err) - result.cleanupContextDeadline() - result.closed = true - <-pgConn.controller - } - - return result -} - -// CopyTo executes the copy command sql and copies the results to w. -func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { - select { - case <-ctx.Done(): - return "", ctx.Err() - case pgConn.controller <- pgConn: - } - cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) - - // Send copy to command - var buf []byte - buf = (&pgproto3.Query{String: sql}).Encode(buf) - - _, err := pgConn.conn.Write(buf) - if err != nil { - pgConn.hardClose() - cleanupContextDeadline() - <-pgConn.controller - - return "", preferContextOverNetTimeoutError(ctx, err) - } - - // Read results - var commandTag CommandTag - var pgErr error - for { - msg, err := pgConn.ReceiveMessage() - if err != nil { - cleanupContextDeadline() - if err, ok := err.(net.Error); ok && err.Timeout() { - go pgConn.recoverFromTimeout() - } else { - <-pgConn.controller - } - - return "", preferContextOverNetTimeoutError(ctx, err) - } - - switch msg := msg.(type) { - case *pgproto3.CopyDone: - case *pgproto3.CopyData: - _, err := w.Write(msg.Data) - if err != nil { - // This isn't actually a timeout, but we want the same behavior. Abort the request and cleanup. - cleanupContextDeadline() - go pgConn.recoverFromTimeout() - return "", err - } - case *pgproto3.ReadyForQuery: - <-pgConn.controller - return commandTag, pgErr - case *pgproto3.CommandComplete: - commandTag = CommandTag(msg.CommandTag) - case *pgproto3.ErrorResponse: - pgErr = errorResponseToPgError(msg) - } - } -} - -// CopyFrom executes the copy command sql and copies all of r to the PostgreSQL server. -// -// Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r -// could still block. -func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { - select { - case <-ctx.Done(): - return "", ctx.Err() - case pgConn.controller <- pgConn: - } - cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) - - // Send copy to command - var buf []byte - buf = (&pgproto3.Query{String: sql}).Encode(buf) - - _, err := pgConn.conn.Write(buf) - if err != nil { - pgConn.hardClose() - cleanupContextDeadline() - <-pgConn.controller - - return "", preferContextOverNetTimeoutError(ctx, err) - } - - // Read until copy in response or error. - var commandTag CommandTag - var pgErr error - pendingCopyInResponse := true - for pendingCopyInResponse { - msg, err := pgConn.ReceiveMessage() - if err != nil { - cleanupContextDeadline() - if err, ok := err.(net.Error); ok && err.Timeout() { - go pgConn.recoverFromTimeoutDuringCopyFrom() - } else { - <-pgConn.controller - } - - return "", preferContextOverNetTimeoutError(ctx, err) - } - - switch msg := msg.(type) { - case *pgproto3.CopyInResponse: - pendingCopyInResponse = false - case *pgproto3.ErrorResponse: - pgErr = errorResponseToPgError(msg) - case *pgproto3.ReadyForQuery: - <-pgConn.controller - return commandTag, pgErr - } - } - - // Send copy data - buf = make([]byte, 0, 20000) - // buf = make([]byte, 0, 65536) - buf = append(buf, 'd') - sp := len(buf) - var readErr error - signalMessageChan := pgConn.signalMessage() - for readErr == nil && pgErr == nil { - var n int - n, readErr = r.Read(buf[5:cap(buf)]) - if n > 0 { - buf = buf[0 : n+5] - pgio.SetInt32(buf[sp:], int32(n+4)) - - _, err = pgConn.conn.Write(buf) - if err != nil { - pgConn.hardClose() - cleanupContextDeadline() - <-pgConn.controller - - return "", preferContextOverNetTimeoutError(ctx, err) - } - } - - select { - case <-signalMessageChan: - msg, err := pgConn.ReceiveMessage() - if err != nil { - cleanupContextDeadline() - if err, ok := err.(net.Error); ok && err.Timeout() { - go pgConn.recoverFromTimeoutDuringCopyFrom() - } else { - <-pgConn.controller - } - - return "", preferContextOverNetTimeoutError(ctx, err) - } - - switch msg := msg.(type) { - case *pgproto3.ErrorResponse: - pgErr = errorResponseToPgError(msg) - } - default: - } - } - - buf = buf[:0] - if readErr == io.EOF || pgErr != nil { - copyDone := &pgproto3.CopyDone{} - buf = copyDone.Encode(buf) - } else { - copyFail := &pgproto3.CopyFail{Error: readErr.Error()} - buf = copyFail.Encode(buf) - } - _, err = pgConn.conn.Write(buf) - if err != nil { - pgConn.hardClose() - - cleanupContextDeadline() - <-pgConn.controller - - return "", preferContextOverNetTimeoutError(ctx, err) - } - - // Read results - for { - msg, err := pgConn.ReceiveMessage() - if err != nil { - cleanupContextDeadline() - if err, ok := err.(net.Error); ok && err.Timeout() { - go pgConn.recoverFromTimeout() - } else { - <-pgConn.controller - } - - return "", preferContextOverNetTimeoutError(ctx, err) - } - - switch msg := msg.(type) { - case *pgproto3.ReadyForQuery: - <-pgConn.controller - return commandTag, pgErr - case *pgproto3.CommandComplete: - commandTag = CommandTag(msg.CommandTag) - case *pgproto3.ErrorResponse: - pgErr = errorResponseToPgError(msg) - } - } -} - -func (pgConn *PgConn) recoverFromTimeoutDuringCopyFrom() { - // Regardless of recovery outcome the lock on the pgConn must be released. - defer func() { <-pgConn.controller }() - - // Limit time to wait for entire cancellation process. - err := pgConn.conn.SetDeadline(time.Now().Add(15 * time.Second)) - if err != nil { - pgConn.hardClose() - return - } - - copyFail := &pgproto3.CopyFail{Error: "client cancel"} - buf := copyFail.Encode(nil) - - _, err = pgConn.conn.Write(buf) - if err != nil { - pgConn.hardClose() - return - } - - pendingReadyForQuery := true - - for pendingReadyForQuery { - msg, err := pgConn.ReceiveMessage() - if err != nil { - pgConn.hardClose() - return - } - - switch msg.(type) { - case *pgproto3.ReadyForQuery: - pendingReadyForQuery = false - } - } - - err = pgConn.conn.SetDeadline(time.Time{}) - if err != nil { - pgConn.hardClose() - } -} - -// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. -type MultiResultReader struct { - pgConn *PgConn - ctx context.Context - cleanupContextDeadline func() - - rr *ResultReader - - closed bool - err error -} - -// ReadAll reads all available results. Calling ReadAll is mutually exclusive with all other MultiResultReader methods. -func (mrr *MultiResultReader) ReadAll() ([]*Result, error) { - var results []*Result - - for mrr.NextResult() { - results = append(results, mrr.ResultReader().Read()) - } - err := mrr.Close() - - return results, err -} - -func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) { - msg, err := mrr.pgConn.ReceiveMessage() - - if err != nil { - mrr.cleanupContextDeadline() - mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) - mrr.closed = true - - if err, ok := err.(net.Error); ok && err.Timeout() { - go mrr.pgConn.recoverFromTimeout() - } else { - <-mrr.pgConn.controller - } - - return nil, mrr.err - } - - switch msg := msg.(type) { - case *pgproto3.ReadyForQuery: - mrr.cleanupContextDeadline() - mrr.closed = true - <-mrr.pgConn.controller - case *pgproto3.ErrorResponse: - mrr.err = errorResponseToPgError(msg) - } - - return msg, nil -} - -// NextResult returns advances the MultiResultReader to the next result and returns true if a result is available. -func (mrr *MultiResultReader) NextResult() bool { - for !mrr.closed && mrr.err == nil { - msg, err := mrr.receiveMessage() - if err != nil { - return false - } - - switch msg := msg.(type) { - case *pgproto3.RowDescription: - mrr.rr = &ResultReader{ - pgConn: mrr.pgConn, - multiResultReader: mrr, - ctx: mrr.ctx, - cleanupContextDeadline: func() {}, - fieldDescriptions: msg.Fields, - } - return true - case *pgproto3.CommandComplete: - mrr.rr = &ResultReader{ - commandTag: CommandTag(msg.CommandTag), - commandConcluded: true, - closed: true, - } - return true - case *pgproto3.EmptyQueryResponse: - return false - } - } - - return false -} - -// ResultReader returns the current ResultReader. -func (mrr *MultiResultReader) ResultReader() *ResultReader { - return mrr.rr -} - -// Close closes the MultiResultReader and returns the first error that occurred during the MultiResultReader's use. -func (mrr *MultiResultReader) Close() error { - for !mrr.closed { - _, err := mrr.receiveMessage() - if err != nil { - return mrr.err - } - } - - return mrr.err -} - -// ResultReader is a reader for the result of a single query. -type ResultReader struct { - pgConn *PgConn - multiResultReader *MultiResultReader - ctx context.Context - cleanupContextDeadline func() - - fieldDescriptions []pgproto3.FieldDescription - rowValues [][]byte - commandTag CommandTag - commandConcluded bool - closed bool - err error -} - -// Result is the saved query response that is returned by calling Read on a ResultReader. -type Result struct { - FieldDescriptions []pgproto3.FieldDescription - Rows [][][]byte - CommandTag CommandTag - Err error -} - -// Read saves the query response to a Result. -func (rr *ResultReader) Read() *Result { - br := &Result{} - - for rr.NextRow() { - if br.FieldDescriptions == nil { - br.FieldDescriptions = make([]pgproto3.FieldDescription, len(rr.FieldDescriptions())) - copy(br.FieldDescriptions, rr.FieldDescriptions()) - } - - row := make([][]byte, len(rr.Values())) - copy(row, rr.Values()) - br.Rows = append(br.Rows, row) - } - - br.CommandTag, br.Err = rr.Close() - - return br -} - -// NextRow advances the ResultReader to the next row and returns true if a row is available. -func (rr *ResultReader) NextRow() bool { - for !rr.commandConcluded { - msg, err := rr.receiveMessage() - if err != nil { - return false - } - - switch msg := msg.(type) { - case *pgproto3.DataRow: - rr.rowValues = msg.Values - return true - } - } - - return false -} - -// FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until -// the ResultReader is closed. -func (rr *ResultReader) FieldDescriptions() []pgproto3.FieldDescription { - return rr.fieldDescriptions -} - -// Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only -// valid until the next NextRow call or the ResultReader is closed. However, the underlying byte data is safe to -// retain a reference to and mutate. -func (rr *ResultReader) Values() [][]byte { - return rr.rowValues -} - -// Close consumes any remaining result data and returns the command tag or -// error. -func (rr *ResultReader) Close() (CommandTag, error) { - if rr.closed { - return rr.commandTag, rr.err - } - rr.closed = true - - for !rr.commandConcluded { - _, err := rr.receiveMessage() - if err != nil { - return "", rr.err - } - } - - if rr.multiResultReader == nil { - for { - msg, err := rr.receiveMessage() - if err != nil { - return "", rr.err - } - - switch msg.(type) { - case *pgproto3.ReadyForQuery: - rr.cleanupContextDeadline() - <-rr.pgConn.controller - return rr.commandTag, rr.err - } - } - } - - return rr.commandTag, rr.err -} - -func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) { - if rr.multiResultReader == nil { - msg, err = rr.pgConn.ReceiveMessage() - } else { - msg, err = rr.multiResultReader.receiveMessage() - } - - if err != nil { - rr.concludeCommand("", err) - rr.cleanupContextDeadline() - rr.closed = true - if rr.multiResultReader == nil { - if err, ok := err.(net.Error); ok && err.Timeout() { - go rr.pgConn.recoverFromTimeout() - } else { - <-rr.pgConn.controller - } - } - - return nil, rr.err - } - - switch msg := msg.(type) { - case *pgproto3.RowDescription: - rr.fieldDescriptions = msg.Fields - case *pgproto3.CommandComplete: - rr.concludeCommand(CommandTag(msg.CommandTag), nil) - case *pgproto3.ErrorResponse: - rr.concludeCommand("", errorResponseToPgError(msg)) - } - - return msg, nil -} - -func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { - if rr.commandConcluded { - return - } - - rr.commandTag = commandTag - rr.err = preferContextOverNetTimeoutError(rr.ctx, err) - rr.fieldDescriptions = nil - rr.rowValues = nil - rr.commandConcluded = true -} - -func (pgConn *PgConn) defaultCancel() { - // Regardless of recovery outcome the lock on the pgConn must be released. - defer func() { <-pgConn.controller }() - - // Send a cancellation request to the PostgreSQL server. If it is not successful in a reasonable amount of time do not - // try further to recover the connection. - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - err := pgConn.cancelRequest(ctx) - cancel() - if err != nil { - pgConn.hardClose() - return - } - - // Limit time to wait for ReadyForQuery message. - err = pgConn.conn.SetDeadline(time.Now().Add(15 * time.Second)) - if err != nil { - pgConn.hardClose() - return - } - - // A cancel query request will always return a "57014" error response, even if no query was in progress. This error - // may be returned before or after the ReadyForQuery message. Must ensure both messages are read. - needError57014 := true - needReadyForQuery := true - - for needError57014 || needReadyForQuery { - msg, err := pgConn.ReceiveMessage() - if err != nil { - pgConn.hardClose() - return - } - - switch msg := msg.(type) { - case *pgproto3.ErrorResponse: - if msg.Code == "57014" { - needError57014 = false - } - case *pgproto3.ReadyForQuery: - needReadyForQuery = false - } - } - - err = pgConn.conn.SetDeadline(time.Time{}) - if err != nil { - pgConn.hardClose() - } -} - -type ContextCancel struct { - PgConn *PgConn -} - -// Finish must be called when the cancellation request has finished processing. The connection must be in a ready for -// query state or the connection must be closed. This must be called regardless of the success of the cancellation and -// whether the connection is still valid or not. It releases an internal busy lock on the connection. -func (cc *ContextCancel) Finish() { - <-cc.PgConn.controller -} - -func (pgConn *PgConn) recoverFromTimeout() { - if pgConn.Config.OnContextCancel == nil { - pgConn.defaultCancel() - } else { - cc := &ContextCancel{PgConn: pgConn} - pgConn.Config.OnContextCancel(cc) - } -} - -// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. -type Batch struct { - buf []byte -} - -// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. -func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { - batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) - batch.ExecPrepared("", paramValues, paramFormats, resultFormats) -} - -// ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions. -func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { - batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) - batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) - batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) -} - -// ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a -// transaction is already in progress or SQL contains transaction control statements. -func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { - multiResult := &MultiResultReader{ - pgConn: pgConn, - ctx: ctx, - cleanupContextDeadline: func() {}, - } - - select { - case <-ctx.Done(): - multiResult.closed = true - multiResult.err = ctx.Err() - return multiResult - case pgConn.controller <- multiResult: - } - multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) - - batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) - _, err := pgConn.conn.Write(batch.buf) - if err != nil { - pgConn.hardClose() - multiResult.cleanupContextDeadline() - multiResult.closed = true - multiResult.err = preferContextOverNetTimeoutError(ctx, err) - <-pgConn.controller - return multiResult - } - - return multiResult -} - -// EscapeString escapes a string such that it can safely be interpolated into a SQL command string. It does not include -// the surrounding single quotes. -// -// The current implementation requires that standard_conforming_strings=on and client_encoding="UTF8". If these -// conditions are not met an error will be returned. It is possible these restrictions will be lifted in the future. -func (pgConn *PgConn) EscapeString(s string) (string, error) { - if pgConn.ParameterStatus("standard_conforming_strings") != "on" { - return "", errors.New("EscapeString must be run with standard_conforming_strings=on") - } - - if pgConn.ParameterStatus("client_encoding") != "UTF8" { - return "", errors.New("EscapeString must be run with client_encoding=UTF8") - } - - return strings.Replace(s, "'", "''", -1), nil -} diff --git a/pgconn/pgconn_stress_test.go b/pgconn/pgconn_stress_test.go deleted file mode 100644 index 7a95fa98..00000000 --- a/pgconn/pgconn_stress_test.go +++ /dev/null @@ -1,120 +0,0 @@ -package pgconn_test - -import ( - "context" - "math/rand" - "os" - "strconv" - "testing" - "time" - - "github.com/jackc/pgx/pgconn" - - "github.com/stretchr/testify/require" -) - -func TestConnStress(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - actionCount := 100 - if s := os.Getenv("PGX_TEST_STRESS_FACTOR"); s != "" { - stressFactor, err := strconv.ParseInt(s, 10, 64) - require.Nil(t, err, "Failed to parse PGX_TEST_STRESS_FACTOR") - actionCount *= int(stressFactor) - } - - setupStressDB(t, pgConn) - - actions := []struct { - name string - fn func(*pgconn.PgConn) error - }{ - {"Exec Select", stressExecSelect}, - {"ExecParams Select", stressExecParamsSelect}, - {"Batch", stressBatch}, - {"ExecCanceled", stressExecSelectCanceled}, - {"ExecParamsCanceled", stressExecParamsSelectCanceled}, - {"BatchCanceled", stressBatchCanceled}, - } - - for i := 0; i < actionCount; i++ { - action := actions[rand.Intn(len(actions))] - err := action.fn(pgConn) - require.Nilf(t, err, "%d: %s", i, action.name) - } -} - -func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) { - _, err := pgConn.Exec(context.Background(), ` - create temporary table widgets( - id serial primary key, - name varchar not null, - description text, - creation_time timestamptz default now() - ); - - insert into widgets(name, description) values - ('Foo', 'bar'), - ('baz', 'Something really long Something really long Something really long Something really long Something really long'), - ('a', 'b')`).ReadAll() - require.NoError(t, err) -} - -func stressExecSelect(pgConn *pgconn.PgConn) error { - _, err := pgConn.Exec(context.Background(), "select * from widgets").ReadAll() - return err -} - -func stressExecParamsSelect(pgConn *pgconn.PgConn) error { - result := pgConn.ExecParams(context.Background(), "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read() - return result.Err -} - -func stressBatch(pgConn *pgconn.PgConn) error { - batch := &pgconn.Batch{} - - batch.ExecParams("select * from widgets", nil, nil, nil, nil) - batch.ExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) - _, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() - return err -} - -func stressExecSelectCanceled(pgConn *pgconn.PgConn) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - _, err := pgConn.Exec(ctx, "select *, pg_sleep(1) from widgets").ReadAll() - cancel() - if err != context.DeadlineExceeded { - return err - } - - return nil -} - -func stressExecParamsSelectCanceled(pgConn *pgconn.PgConn) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - result := pgConn.ExecParams(ctx, "select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read() - cancel() - if result.Err != context.DeadlineExceeded { - return result.Err - } - - return nil -} - -func stressBatchCanceled(pgConn *pgconn.PgConn) error { - batch := &pgconn.Batch{} - batch.ExecParams("select * from widgets", nil, nil, nil, nil) - batch.ExecParams("select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - _, err := pgConn.ExecBatch(ctx, batch).ReadAll() - cancel() - if err != context.DeadlineExceeded { - return err - } - - return nil -} diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go deleted file mode 100644 index dbf9b840..00000000 --- a/pgconn/pgconn_test.go +++ /dev/null @@ -1,1016 +0,0 @@ -package pgconn_test - -import ( - "bytes" - "compress/gzip" - "context" - "crypto/tls" - "fmt" - "io" - "io/ioutil" - "log" - "net" - "os" - "strconv" - "testing" - "time" - - "github.com/jackc/pgx/pgconn" - "github.com/jackc/pgx/pgproto3" - "github.com/pkg/errors" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestConnect(t *testing.T) { - tests := []struct { - name string - env string - }{ - {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, - {"TCP", "PGX_TEST_TCP_CONN_STRING"}, - {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, - {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - connString := os.Getenv(tt.env) - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", tt.env) - } - - conn, err := pgconn.Connect(context.Background(), connString) - require.NoError(t, err) - - closeConn(t, conn) - }) - } -} - -// TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure -// connection. -func TestConnectTLS(t *testing.T) { - t.Parallel() - - connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") - } - - conn, err := pgconn.Connect(context.Background(), connString) - require.NoError(t, err) - - if _, ok := conn.Conn().(*tls.Conn); !ok { - t.Error("not a TLS connection") - } - - closeConn(t, conn) -} - -func TestConnectInvalidUser(t *testing.T) { - t.Parallel() - - connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") - } - - config, err := pgconn.ParseConfig(connString) - require.NoError(t, err) - - config.User = "pgxinvalidusertest" - - conn, err := pgconn.ConnectConfig(context.Background(), config) - if err == nil { - conn.Close(context.Background()) - t.Fatal("expected err but got none") - } - pgErr, ok := err.(*pgconn.PgError) - if !ok { - t.Fatalf("Expected to receive a PgError, instead received: %v", err) - } - if pgErr.Code != "28000" && pgErr.Code != "28P01" { - t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) - } -} - -func TestConnectWithConnectionRefused(t *testing.T) { - t.Parallel() - - // Presumably nothing is listening on 127.0.0.1:1 - conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1") - if err == nil { - conn.Close(context.Background()) - t.Fatal("Expected error establishing connection to bad port") - } -} - -func TestConnectCustomDialer(t *testing.T) { - t.Parallel() - - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - dialed := false - config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialed = true - return net.Dial(network, address) - } - - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - require.True(t, dialed) - closeConn(t, conn) -} - -func TestConnectWithRuntimeParams(t *testing.T) { - t.Parallel() - - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - config.RuntimeParams = map[string]string{ - "application_name": "pgxtest", - "search_path": "myschema", - } - - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, conn) - - result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).Read() - require.Nil(t, result.Err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "pgxtest", string(result.Rows[0][0])) - - result = conn.ExecParams(context.Background(), "show search_path", nil, nil, nil, nil).Read() - require.Nil(t, result.Err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "myschema", string(result.Rows[0][0])) -} - -func TestConnectWithFallback(t *testing.T) { - t.Parallel() - - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - // Prepend current primary config to fallbacks - config.Fallbacks = append([]*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: config.Host, - Port: config.Port, - TLSConfig: config.TLSConfig, - }, - }, config.Fallbacks...) - - // Make primary config bad - config.Host = "localhost" - config.Port = 1 // presumably nothing listening here - - // Prepend bad first fallback - config.Fallbacks = append([]*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: "localhost", - Port: 1, - TLSConfig: config.TLSConfig, - }, - }, config.Fallbacks...) - - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - closeConn(t, conn) -} - -func TestConnectWithAfterConnectFunc(t *testing.T) { - t.Parallel() - - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - dialCount := 0 - config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialCount += 1 - return net.Dial(network, address) - } - - acceptConnCount := 0 - config.AfterConnectFunc = func(ctx context.Context, conn *pgconn.PgConn) error { - acceptConnCount += 1 - if acceptConnCount < 2 { - return errors.New("reject first conn") - } - return nil - } - - // Append current primary config to fallbacks - config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ - Host: config.Host, - Port: config.Port, - TLSConfig: config.TLSConfig, - }) - - // Repeat fallbacks - config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) - - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - closeConn(t, conn) - - assert.True(t, dialCount > 1) - assert.True(t, acceptConnCount > 1) -} - -func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { - t.Parallel() - - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - config.AfterConnectFunc = pgconn.AfterConnectTargetSessionAttrsReadWrite - config.RuntimeParams["default_transaction_read_only"] = "on" - - conn, err := pgconn.ConnectConfig(context.Background(), config) - if !assert.NotNil(t, err) { - conn.Close(context.Background()) - } -} - -func TestConnPrepareFailure(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - psd, err := pgConn.Prepare(context.Background(), "ps1", "SYNTAX ERROR", nil) - require.Nil(t, psd) - require.NotNil(t, err) - - ensureConnValid(t, pgConn) -} - -func TestConnExec(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() - assert.NoError(t, err) - - assert.Len(t, results, 1) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - - ensureConnValid(t, pgConn) -} - -func TestConnExecEmpty(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - multiResult := pgConn.Exec(context.Background(), ";") - - resultCount := 0 - for multiResult.NextResult() { - resultCount += 1 - multiResult.ResultReader().Close() - } - assert.Equal(t, 0, resultCount) - err = multiResult.Close() - assert.NoError(t, err) - - ensureConnValid(t, pgConn) -} - -func TestConnExecMultipleQueries(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - results, err := pgConn.Exec(context.Background(), "select 'Hello, world'; select 1").ReadAll() - assert.NoError(t, err) - - assert.Len(t, results, 2) - - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - - assert.Nil(t, results[1].Err) - assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) - assert.Len(t, results[1].Rows, 1) - assert.Equal(t, "1", string(results[1].Rows[0][0])) - - ensureConnValid(t, pgConn) -} - -func TestConnExecMultipleQueriesError(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - results, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1").ReadAll() - require.NotNil(t, err) - if pgErr, ok := err.(*pgconn.PgError); ok { - assert.Equal(t, "22012", pgErr.Code) - } else { - t.Errorf("unexpected error: %v", err) - } - - assert.Len(t, results, 1) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "1", string(results[0].Rows[0][0])) - - ensureConnValid(t, pgConn) -} - -func TestConnExecContextCanceled(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(1)") - - for multiResult.NextResult() { - } - err = multiResult.Close() - assert.Equal(t, context.DeadlineExceeded, err) - - ensureConnValid(t, pgConn) -} - -func TestConnExecParams(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - result := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) - rowCount := 0 - for result.NextRow() { - rowCount += 1 - assert.Equal(t, "Hello, world", string(result.Values()[0])) - } - assert.Equal(t, 1, rowCount) - commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.NoError(t, err) - - ensureConnValid(t, pgConn) -} - -func TestConnExecParamsCanceled(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil) - rowCount := 0 - for result.NextRow() { - rowCount += 1 - } - assert.Equal(t, 0, rowCount) - commandTag, err := result.Close() - assert.Equal(t, pgconn.CommandTag(""), commandTag) - assert.Equal(t, context.DeadlineExceeded, err) - - ensureConnValid(t, pgConn) -} - -func TestConnExecPrepared(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, 1) - assert.Len(t, psd.Fields, 1) - - result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) - rowCount := 0 - for result.NextRow() { - rowCount += 1 - assert.Equal(t, "Hello, world", string(result.Values()[0])) - } - assert.Equal(t, 1, rowCount) - commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.NoError(t, err) - - ensureConnValid(t, pgConn) -} - -func TestConnExecPreparedCanceled(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - _, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil) - require.NoError(t, err) - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil) - rowCount := 0 - for result.NextRow() { - rowCount += 1 - } - assert.Equal(t, 0, rowCount) - commandTag, err := result.Close() - assert.Equal(t, pgconn.CommandTag(""), commandTag) - assert.Equal(t, context.DeadlineExceeded, err) - - ensureConnValid(t, pgConn) -} - -func TestConnExecBatch(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) - require.NoError(t, err) - - batch := &pgconn.Batch{} - - batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) - batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) - batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) - results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() - require.NoError(t, err) - require.Len(t, results, 3) - - require.Len(t, results[0].Rows, 1) - require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - - require.Len(t, results[1].Rows, 1) - require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) - - require.Len(t, results[2].Rows, 1) - require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) -} - -func TestCommandTag(t *testing.T) { - t.Parallel() - - var tests = []struct { - commandTag pgconn.CommandTag - rowsAffected int64 - }{ - {commandTag: pgconn.CommandTag("INSERT 0 5"), rowsAffected: 5}, - {commandTag: pgconn.CommandTag("UPDATE 0"), rowsAffected: 0}, - {commandTag: pgconn.CommandTag("UPDATE 1"), rowsAffected: 1}, - {commandTag: pgconn.CommandTag("DELETE 0"), rowsAffected: 0}, - {commandTag: pgconn.CommandTag("DELETE 1"), rowsAffected: 1}, - {commandTag: pgconn.CommandTag("CREATE TABLE"), rowsAffected: 0}, - {commandTag: pgconn.CommandTag("ALTER TABLE"), rowsAffected: 0}, - {commandTag: pgconn.CommandTag("DROP TABLE"), rowsAffected: 0}, - } - - for i, tt := range tests { - actual := tt.commandTag.RowsAffected() - assert.Equalf(t, tt.rowsAffected, actual, "%d. %v", i, tt.commandTag) - } -} - -func TestConnContextCancelWithOnContextCancel(t *testing.T) { - t.Parallel() - - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - calledChan := make(chan struct{}) - - config.OnContextCancel = func(cc *pgconn.ContextCancel) { - defer cc.Finish() - close(calledChan) - - for { - msg, err := cc.PgConn.ReceiveMessage() - if err != nil { - cc.PgConn.Close(context.Background()) - return - } - - switch msg.(type) { - case *pgproto3.ReadyForQuery: - return - } - } - } - - pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, pgConn) - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - result := pgConn.ExecParams(ctx, "select 'Hello, world', pg_sleep(0.25)", nil, nil, nil, nil) - _, err = result.Close() - assert.Equal(t, context.DeadlineExceeded, err) - - called := false - select { - case <-calledChan: - called = true - case <-time.NewTimer(time.Second).C: - } - - assert.True(t, called) - - ensureConnValid(t, pgConn) -} - -func TestConnWaitUntilReady(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil).Read() - assert.Equal(t, context.DeadlineExceeded, result.Err) - - err = pgConn.WaitUntilReady(context.Background()) - require.NoError(t, err) - - ensureConnValid(t, pgConn) -} - -func TestConnOnNotice(t *testing.T) { - t.Parallel() - - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - var msg string - config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { - msg = notice.Message - } - - pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, pgConn) - - multiResult := pgConn.Exec(context.Background(), `do $$ -begin - raise notice 'hello, world'; -end$$;`) - err = multiResult.Close() - require.NoError(t, err) - assert.Equal(t, "hello, world", msg) - - ensureConnValid(t, pgConn) -} - -func TestConnOnNotification(t *testing.T) { - t.Parallel() - - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - var msg string - config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { - msg = n.Payload - } - - pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, pgConn) - - _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() - require.NoError(t, err) - - notifier, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, notifier) - _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() - require.NoError(t, err) - - _, err = pgConn.Exec(context.Background(), "select 1").ReadAll() - require.NoError(t, err) - - assert.Equal(t, "bar", msg) - - ensureConnValid(t, pgConn) -} - -func TestConnWaitForNotification(t *testing.T) { - t.Parallel() - - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - var msg string - config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { - msg = n.Payload - } - - pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, pgConn) - - _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() - require.NoError(t, err) - - notifier, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, notifier) - _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() - require.NoError(t, err) - - err = pgConn.WaitForNotification(context.Background()) - require.NoError(t, err) - - assert.Equal(t, "bar", msg) - - ensureConnValid(t, pgConn) -} - -func TestConnWaitForNotificationTimeout(t *testing.T) { - t.Parallel() - - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, pgConn) - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - err = pgConn.WaitForNotification(ctx) - cancel() - require.Equal(t, context.DeadlineExceeded, err) - - ensureConnValid(t, pgConn) -} - -func TestConnCopyToSmall(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int2, - b int4, - c int8, - d varchar, - e text, - f date, - g json - )`).ReadAll() - require.NoError(t, err) - - _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() - require.NoError(t, err) - - _, err = pgConn.Exec(context.Background(), `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() - require.NoError(t, err) - - inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + - "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") - - outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) - - res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") - require.NoError(t, err) - - assert.Equal(t, int64(2), res.RowsAffected()) - assert.Equal(t, inputBytes, outputWriter.Bytes()) - - ensureConnValid(t, pgConn) -} - -func TestConnCopyToLarge(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int2, - b int4, - c int8, - d varchar, - e text, - f date, - g json, - h bytea - )`).ReadAll() - require.NoError(t, err) - - inputBytes := make([]byte, 0) - - for i := 0; i < 1000; i++ { - _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() - require.NoError(t, err) - inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) - } - - outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) - - res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") - require.NoError(t, err) - - assert.Equal(t, int64(1000), res.RowsAffected()) - assert.Equal(t, inputBytes, outputWriter.Bytes()) - - ensureConnValid(t, pgConn) -} - -func TestConnCopyToQueryError(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - outputWriter := bytes.NewBuffer(make([]byte, 0)) - - res, err := pgConn.CopyTo(context.Background(), outputWriter, "cropy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) - - ensureConnValid(t, pgConn) -} - -func TestConnCopyToCanceled(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - outputWriter := &bytes.Buffer{} - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") - assert.Equal(t, context.DeadlineExceeded, err) - assert.Equal(t, pgconn.CommandTag(""), res) - - ensureConnValid(t, pgConn) -} - -func TestConnCopyFrom(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) - - srcBuf := &bytes.Buffer{} - - inputRows := [][][]byte{} - for i := 0; i < 1000; i++ { - a := strconv.Itoa(i) - b := "foo " + a + " bar" - inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) - _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) - require.NoError(t, err) - } - - ct, err := pgConn.CopyFrom(context.Background(), srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") - require.NoError(t, err) - assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) - - result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() - require.NoError(t, result.Err) - - assert.Equal(t, inputRows, result.Rows) - - ensureConnValid(t, pgConn) -} - -func TestConnCopyFromCanceled(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) - - r, w := io.Pipe() - go func() { - for i := 0; i < 1000000; i++ { - a := strconv.Itoa(i) - b := "foo " + a + " bar" - _, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) - if err != nil { - return - } - time.Sleep(time.Microsecond) - } - }() - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") - cancel() - assert.Equal(t, int64(0), ct.RowsAffected()) - require.Equal(t, context.DeadlineExceeded, err) - - assert.False(t, pgConn.IsAlive()) -} - -func TestConnCopyFromGzipReader(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) - - f, err := ioutil.TempFile("", "*") - require.NoError(t, err) - - gw := gzip.NewWriter(f) - - inputRows := [][][]byte{} - for i := 0; i < 1000; i++ { - a := strconv.Itoa(i) - b := "foo " + a + " bar" - inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) - _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) - require.NoError(t, err) - } - - err = gw.Close() - require.NoError(t, err) - - _, err = f.Seek(0, 0) - require.NoError(t, err) - - gr, err := gzip.NewReader(f) - require.NoError(t, err) - - ct, err := pgConn.CopyFrom(context.Background(), gr, "COPY foo FROM STDIN WITH (FORMAT csv)") - require.NoError(t, err) - assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) - - err = gr.Close() - require.NoError(t, err) - - err = f.Close() - require.NoError(t, err) - - err = os.Remove(f.Name()) - require.NoError(t, err) - - result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() - require.NoError(t, result.Err) - - assert.Equal(t, inputRows, result.Rows) - - ensureConnValid(t, pgConn) -} - -func TestConnCopyFromQuerySyntaxError(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) - - srcBuf := &bytes.Buffer{} - - res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) - - ensureConnValid(t, pgConn) -} - -func TestConnCopyFromQueryNoTableError(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - srcBuf := &bytes.Buffer{} - - res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) - - ensureConnValid(t, pgConn) -} - -func TestConnEscapeString(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - tests := []struct { - in string - out string - }{ - {in: "", out: ""}, - {in: "42", out: "42"}, - {in: "'", out: "''"}, - {in: "hi'there", out: "hi''there"}, - {in: "'hi there'", out: "''hi there''"}, - } - - for i, tt := range tests { - value, err := pgConn.EscapeString(tt.in) - if assert.NoErrorf(t, err, "%d.", i) { - assert.Equalf(t, tt.out, value, "%d.", i) - } - } - - ensureConnValid(t, pgConn) -} - -func Example() { - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - if err != nil { - log.Fatalln(err) - } - defer pgConn.Close(context.Background()) - - result := pgConn.ExecParams(context.Background(), "select generate_series(1,3)", nil, nil, nil, nil).Read() - if result.Err != nil { - log.Fatalln(result.Err) - } - - for _, row := range result.Rows { - fmt.Println(string(row[0])) - } - - fmt.Println(result.CommandTag) - // Output: - // 1 - // 2 - // 3 - // SELECT 3 -} diff --git a/pgio/doc.go b/pgio/doc.go deleted file mode 100644 index ef2dcc7f..00000000 --- a/pgio/doc.go +++ /dev/null @@ -1,6 +0,0 @@ -// Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol. -/* -pgio provides functions for appending integers to a []byte while doing byte -order conversion. -*/ -package pgio diff --git a/pgio/write.go b/pgio/write.go deleted file mode 100644 index 96aedf9d..00000000 --- a/pgio/write.go +++ /dev/null @@ -1,40 +0,0 @@ -package pgio - -import "encoding/binary" - -func AppendUint16(buf []byte, n uint16) []byte { - wp := len(buf) - buf = append(buf, 0, 0) - binary.BigEndian.PutUint16(buf[wp:], n) - return buf -} - -func AppendUint32(buf []byte, n uint32) []byte { - wp := len(buf) - buf = append(buf, 0, 0, 0, 0) - binary.BigEndian.PutUint32(buf[wp:], n) - return buf -} - -func AppendUint64(buf []byte, n uint64) []byte { - wp := len(buf) - buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0) - binary.BigEndian.PutUint64(buf[wp:], n) - return buf -} - -func AppendInt16(buf []byte, n int16) []byte { - return AppendUint16(buf, uint16(n)) -} - -func AppendInt32(buf []byte, n int32) []byte { - return AppendUint32(buf, uint32(n)) -} - -func AppendInt64(buf []byte, n int64) []byte { - return AppendUint64(buf, uint64(n)) -} - -func SetInt32(buf []byte, n int32) { - binary.BigEndian.PutUint32(buf, uint32(n)) -} diff --git a/pgio/write_test.go b/pgio/write_test.go deleted file mode 100644 index bd50e71c..00000000 --- a/pgio/write_test.go +++ /dev/null @@ -1,78 +0,0 @@ -package pgio - -import ( - "reflect" - "testing" -) - -func TestAppendUint16NilBuf(t *testing.T) { - buf := AppendUint16(nil, 1) - if !reflect.DeepEqual(buf, []byte{0, 1}) { - t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1}) - } -} - -func TestAppendUint16EmptyBuf(t *testing.T) { - buf := []byte{} - buf = AppendUint16(buf, 1) - if !reflect.DeepEqual(buf, []byte{0, 1}) { - t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1}) - } -} - -func TestAppendUint16BufWithCapacityDoesNotAllocate(t *testing.T) { - buf := make([]byte, 0, 4) - AppendUint16(buf, 1) - buf = buf[0:2] - if !reflect.DeepEqual(buf, []byte{0, 1}) { - t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1}) - } -} - -func TestAppendUint32NilBuf(t *testing.T) { - buf := AppendUint32(nil, 1) - if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) { - t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1}) - } -} - -func TestAppendUint32EmptyBuf(t *testing.T) { - buf := []byte{} - buf = AppendUint32(buf, 1) - if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) { - t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1}) - } -} - -func TestAppendUint32BufWithCapacityDoesNotAllocate(t *testing.T) { - buf := make([]byte, 0, 4) - AppendUint32(buf, 1) - buf = buf[0:4] - if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) { - t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1}) - } -} - -func TestAppendUint64NilBuf(t *testing.T) { - buf := AppendUint64(nil, 1) - if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) { - t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) - } -} - -func TestAppendUint64EmptyBuf(t *testing.T) { - buf := []byte{} - buf = AppendUint64(buf, 1) - if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) { - t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) - } -} - -func TestAppendUint64BufWithCapacityDoesNotAllocate(t *testing.T) { - buf := make([]byte, 0, 8) - AppendUint64(buf, 1) - buf = buf[0:8] - if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) { - t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) - } -} diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go index 4d15f7b8..2dde8609 100644 --- a/pgmock/pgmock.go +++ b/pgmock/pgmock.go @@ -7,7 +7,7 @@ import ( "github.com/pkg/errors" - "github.com/jackc/pgx/pgproto3" + "github.com/jackc/pgproto3" "github.com/jackc/pgx/pgtype" ) @@ -43,7 +43,7 @@ func (s *Server) ServeOne() error { s.Close() - backend, err := pgproto3.NewBackend(conn, conn) + backend, err := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn) if err != nil { conn.Close() return err diff --git a/pgpassfile/pgpass.go b/pgpassfile/pgpass.go deleted file mode 100644 index cd249bde..00000000 --- a/pgpassfile/pgpass.go +++ /dev/null @@ -1,109 +0,0 @@ -package pgpassfile - -import ( - "bufio" - "io" - "os" - "regexp" - "strings" -) - -// Entry represents a line in a PG passfile. -type Entry struct { - Hostname string - Port string - Database string - Username string - Password string -} - -// Passfile is the in memory data structure representing a PG passfile. -type Passfile struct { - Entries []*Entry -} - -// ReadPassfile reads the file at path and parses it into a Passfile. -func ReadPassfile(path string) (*Passfile, error) { - f, err := os.Open(path) - if err != nil { - return nil, err - } - defer f.Close() - - return ParsePassfile(f) -} - -// ParsePassfile reads r and parses it into a Passfile. -func ParsePassfile(r io.Reader) (*Passfile, error) { - passfile := &Passfile{} - - scanner := bufio.NewScanner(r) - for scanner.Scan() { - entry := parseLine(scanner.Text()) - if entry != nil { - passfile.Entries = append(passfile.Entries, entry) - } - } - - return passfile, scanner.Err() -} - -// Match (not colons or escaped colon or escaped backslash)+. Essentially gives a split on unescaped -// colon. -var colonSplitterRegexp = regexp.MustCompile("(([^:]|(\\:)))+") - -// var colonSplitterRegexp = regexp.MustCompile("((?:[^:]|(?:\\:)|(?:\\\\))+)") - -// parseLine parses a line into an *Entry. It returns nil on comment lines or any other unparsable -// line. -func parseLine(line string) *Entry { - const ( - tmpBackslash = "\r" - tmpColon = "\n" - ) - - line = strings.TrimSpace(line) - - if strings.HasPrefix(line, "#") { - return nil - } - - line = strings.Replace(line, `\\`, tmpBackslash, -1) - line = strings.Replace(line, `\:`, tmpColon, -1) - - parts := strings.Split(line, ":") - if len(parts) != 5 { - return nil - } - - // Unescape escaped colons and backslashes - for i := range parts { - parts[i] = strings.Replace(parts[i], tmpBackslash, `\`, -1) - parts[i] = strings.Replace(parts[i], tmpColon, `:`, -1) - } - - return &Entry{ - Hostname: parts[0], - Port: parts[1], - Database: parts[2], - Username: parts[3], - Password: parts[4], - } -} - -// FindPassword finds the password for the provided hostname, port, database, and username. For a -// Unix domain socket hostname must be set to "localhost". An empty string will be returned if no -// match is found. -// -// See https://www.postgresql.org/docs/current/libpq-pgpass.html for more password file information. -func (pf *Passfile) FindPassword(hostname, port, database, username string) (password string) { - for _, e := range pf.Entries { - if (e.Hostname == "*" || e.Hostname == hostname) && - (e.Port == "*" || e.Port == port) && - (e.Database == "*" || e.Database == database) && - (e.Username == "*" || e.Username == username) { - return e.Password - } - } - return "" -} diff --git a/pgpassfile/pgpass_test.go b/pgpassfile/pgpass_test.go deleted file mode 100644 index adf7f2af..00000000 --- a/pgpassfile/pgpass_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package pgpassfile - -import ( - "bytes" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func unescape(s string) string { - s = strings.Replace(s, `\:`, `:`, -1) - s = strings.Replace(s, `\\`, `\`, -1) - return s -} - -var passfile = [][]string{ - {"test1", "5432", "larrydb", "larry", "whatstheidea"}, - {"test1", "5432", "moedb", "moe", "imbecile"}, - {"test1", "5432", "curlydb", "curly", "nyuknyuknyuk"}, - {"test2", "5432", "*", "shemp", "heymoe"}, - {"test2", "5432", "*", "*", `test\\ing\:`}, - {"localhost", "*", "*", "*", "sesam"}, - {"test3", "*", "", "", "swordfish"}, // user will be filled later -} - -func TestParsePassFile(t *testing.T) { - buf := bytes.NewBufferString(`# A comment - test1:5432:larrydb:larry:whatstheidea - test1:5432:moedb:moe:imbecile - test1:5432:curlydb:curly:nyuknyuknyuk - test2:5432:*:shemp:heymoe - test2:5432:*:*:test\\ing\: - localhost:*:*:*:sesam - `) - - passfile, err := ParsePassfile(buf) - require.Nil(t, err) - - assert.Len(t, passfile.Entries, 6) - - assert.Equal(t, "whatstheidea", passfile.FindPassword("test1", "5432", "larrydb", "larry")) - assert.Equal(t, "imbecile", passfile.FindPassword("test1", "5432", "moedb", "moe")) - assert.Equal(t, `test\ing:`, passfile.FindPassword("test2", "5432", "something", "else")) - assert.Equal(t, "sesam", passfile.FindPassword("localhost", "9999", "foo", "bare")) - - assert.Equal(t, "", passfile.FindPassword("wrong", "5432", "larrydb", "larry")) - assert.Equal(t, "", passfile.FindPassword("test1", "wrong", "larrydb", "larry")) - assert.Equal(t, "", passfile.FindPassword("test1", "5432", "wrong", "larry")) - assert.Equal(t, "", passfile.FindPassword("test1", "5432", "larrydb", "wrong")) -} diff --git a/pgproto3/authentication.go b/pgproto3/authentication.go deleted file mode 100644 index 77750b86..00000000 --- a/pgproto3/authentication.go +++ /dev/null @@ -1,54 +0,0 @@ -package pgproto3 - -import ( - "encoding/binary" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -const ( - AuthTypeOk = 0 - AuthTypeCleartextPassword = 3 - AuthTypeMD5Password = 5 -) - -type Authentication struct { - Type uint32 - - // MD5Password fields - Salt [4]byte -} - -func (*Authentication) Backend() {} - -func (dst *Authentication) Decode(src []byte) error { - *dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])} - - switch dst.Type { - case AuthTypeOk: - case AuthTypeCleartextPassword: - case AuthTypeMD5Password: - copy(dst.Salt[:], src[4:8]) - default: - return errors.Errorf("unknown authentication type: %d", dst.Type) - } - - return nil -} - -func (src *Authentication) Encode(dst []byte) []byte { - dst = append(dst, 'R') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - dst = pgio.AppendUint32(dst, src.Type) - - switch src.Type { - case AuthTypeMD5Password: - dst = append(dst, src.Salt[:]...) - } - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst -} diff --git a/pgproto3/backend.go b/pgproto3/backend.go deleted file mode 100644 index ea44d1d1..00000000 --- a/pgproto3/backend.go +++ /dev/null @@ -1,113 +0,0 @@ -package pgproto3 - -import ( - "encoding/binary" - "io" - - "github.com/jackc/pgx/chunkreader" - "github.com/pkg/errors" -) - -type Backend struct { - cr *chunkreader.ChunkReader - w io.Writer - - // Frontend message flyweights - bind Bind - _close Close - copyFail CopyFail - describe Describe - execute Execute - flush Flush - parse Parse - passwordMessage PasswordMessage - query Query - startupMessage StartupMessage - sync Sync - terminate Terminate - - bodyLen int - msgType byte - partialMsg bool -} - -func NewBackend(r io.Reader, w io.Writer) (*Backend, error) { - cr := chunkreader.NewChunkReader(r) - return &Backend{cr: cr, w: w}, nil -} - -func (b *Backend) Send(msg BackendMessage) error { - _, err := b.w.Write(msg.Encode(nil)) - return err -} - -func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) { - buf, err := b.cr.Next(4) - if err != nil { - return nil, err - } - msgSize := int(binary.BigEndian.Uint32(buf) - 4) - - buf, err = b.cr.Next(msgSize) - if err != nil { - return nil, err - } - - err = b.startupMessage.Decode(buf) - if err != nil { - return nil, err - } - - return &b.startupMessage, nil -} - -func (b *Backend) Receive() (FrontendMessage, error) { - if !b.partialMsg { - header, err := b.cr.Next(5) - if err != nil { - return nil, err - } - - b.msgType = header[0] - b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 - b.partialMsg = true - } - - var msg FrontendMessage - switch b.msgType { - case 'B': - msg = &b.bind - case 'C': - msg = &b._close - case 'D': - msg = &b.describe - case 'E': - msg = &b.execute - case 'f': - msg = &b.copyFail - case 'H': - msg = &b.flush - case 'P': - msg = &b.parse - case 'p': - msg = &b.passwordMessage - case 'Q': - msg = &b.query - case 'S': - msg = &b.sync - case 'X': - msg = &b.terminate - default: - return nil, errors.Errorf("unknown message type: %c", b.msgType) - } - - msgBody, err := b.cr.Next(b.bodyLen) - if err != nil { - return nil, err - } - - b.partialMsg = false - - err = msg.Decode(msgBody) - return msg, err -} diff --git a/pgproto3/backend_key_data.go b/pgproto3/backend_key_data.go deleted file mode 100644 index 5a478f10..00000000 --- a/pgproto3/backend_key_data.go +++ /dev/null @@ -1,46 +0,0 @@ -package pgproto3 - -import ( - "encoding/binary" - "encoding/json" - - "github.com/jackc/pgx/pgio" -) - -type BackendKeyData struct { - ProcessID uint32 - SecretKey uint32 -} - -func (*BackendKeyData) Backend() {} - -func (dst *BackendKeyData) Decode(src []byte) error { - if len(src) != 8 { - return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)} - } - - dst.ProcessID = binary.BigEndian.Uint32(src[:4]) - dst.SecretKey = binary.BigEndian.Uint32(src[4:]) - - return nil -} - -func (src *BackendKeyData) Encode(dst []byte) []byte { - dst = append(dst, 'K') - dst = pgio.AppendUint32(dst, 12) - dst = pgio.AppendUint32(dst, src.ProcessID) - dst = pgio.AppendUint32(dst, src.SecretKey) - return dst -} - -func (src *BackendKeyData) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - ProcessID uint32 - SecretKey uint32 - }{ - Type: "BackendKeyData", - ProcessID: src.ProcessID, - SecretKey: src.SecretKey, - }) -} diff --git a/pgproto3/backend_test.go b/pgproto3/backend_test.go deleted file mode 100644 index 02a5e9ca..00000000 --- a/pgproto3/backend_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package pgproto3_test - -import ( - "testing" - - "github.com/jackc/pgx/pgproto3" -) - -func TestBackendReceiveInterrupted(t *testing.T) { - t.Parallel() - - server := &interruptReader{} - server.push([]byte{'Q', 0, 0, 0, 6}) - - backend, err := pgproto3.NewBackend(server, nil) - if err != nil { - t.Fatal(err) - } - - msg, err := backend.Receive() - if err == nil { - t.Fatal("expected err") - } - if msg != nil { - t.Fatalf("did not expect msg, but %v", msg) - } - - server.push([]byte{'I', 0}) - - msg, err = backend.Receive() - if err != nil { - t.Fatal(err) - } - if msg, ok := msg.(*pgproto3.Query); !ok || msg.String != "I" { - t.Fatalf("unexpected msg: %v", msg) - } -} diff --git a/pgproto3/big_endian.go b/pgproto3/big_endian.go deleted file mode 100644 index f7bdb97e..00000000 --- a/pgproto3/big_endian.go +++ /dev/null @@ -1,37 +0,0 @@ -package pgproto3 - -import ( - "encoding/binary" -) - -type BigEndianBuf [8]byte - -func (b BigEndianBuf) Int16(n int16) []byte { - buf := b[0:2] - binary.BigEndian.PutUint16(buf, uint16(n)) - return buf -} - -func (b BigEndianBuf) Uint16(n uint16) []byte { - buf := b[0:2] - binary.BigEndian.PutUint16(buf, n) - return buf -} - -func (b BigEndianBuf) Int32(n int32) []byte { - buf := b[0:4] - binary.BigEndian.PutUint32(buf, uint32(n)) - return buf -} - -func (b BigEndianBuf) Uint32(n uint32) []byte { - buf := b[0:4] - binary.BigEndian.PutUint32(buf, n) - return buf -} - -func (b BigEndianBuf) Int64(n int64) []byte { - buf := b[0:8] - binary.BigEndian.PutUint64(buf, uint64(n)) - return buf -} diff --git a/pgproto3/bind.go b/pgproto3/bind.go deleted file mode 100644 index cceee6ab..00000000 --- a/pgproto3/bind.go +++ /dev/null @@ -1,171 +0,0 @@ -package pgproto3 - -import ( - "bytes" - "encoding/binary" - "encoding/hex" - "encoding/json" - - "github.com/jackc/pgx/pgio" -) - -type Bind struct { - DestinationPortal string - PreparedStatement string - ParameterFormatCodes []int16 - Parameters [][]byte - ResultFormatCodes []int16 -} - -func (*Bind) Frontend() {} - -func (dst *Bind) Decode(src []byte) error { - *dst = Bind{} - - idx := bytes.IndexByte(src, 0) - if idx < 0 { - return &invalidMessageFormatErr{messageType: "Bind"} - } - dst.DestinationPortal = string(src[:idx]) - rp := idx + 1 - - idx = bytes.IndexByte(src[rp:], 0) - if idx < 0 { - return &invalidMessageFormatErr{messageType: "Bind"} - } - dst.PreparedStatement = string(src[rp : rp+idx]) - rp += idx + 1 - - if len(src[rp:]) < 2 { - return &invalidMessageFormatErr{messageType: "Bind"} - } - parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:])) - rp += 2 - - if parameterFormatCodeCount > 0 { - dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount) - - if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 { - return &invalidMessageFormatErr{messageType: "Bind"} - } - for i := 0; i < parameterFormatCodeCount; i++ { - dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) - rp += 2 - } - } - - if len(src[rp:]) < 2 { - return &invalidMessageFormatErr{messageType: "Bind"} - } - parameterCount := int(binary.BigEndian.Uint16(src[rp:])) - rp += 2 - - if parameterCount > 0 { - dst.Parameters = make([][]byte, parameterCount) - - for i := 0; i < parameterCount; i++ { - if len(src[rp:]) < 4 { - return &invalidMessageFormatErr{messageType: "Bind"} - } - - msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - - // null - if msgSize == -1 { - continue - } - - if len(src[rp:]) < msgSize { - return &invalidMessageFormatErr{messageType: "Bind"} - } - - dst.Parameters[i] = src[rp : rp+msgSize] - rp += msgSize - } - } - - if len(src[rp:]) < 2 { - return &invalidMessageFormatErr{messageType: "Bind"} - } - resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:])) - rp += 2 - - dst.ResultFormatCodes = make([]int16, resultFormatCodeCount) - if len(src[rp:]) < len(dst.ResultFormatCodes)*2 { - return &invalidMessageFormatErr{messageType: "Bind"} - } - for i := 0; i < resultFormatCodeCount; i++ { - dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) - rp += 2 - } - - return nil -} - -func (src *Bind) Encode(dst []byte) []byte { - dst = append(dst, 'B') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - - dst = append(dst, src.DestinationPortal...) - dst = append(dst, 0) - dst = append(dst, src.PreparedStatement...) - dst = append(dst, 0) - - dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes))) - for _, fc := range src.ParameterFormatCodes { - dst = pgio.AppendInt16(dst, fc) - } - - dst = pgio.AppendUint16(dst, uint16(len(src.Parameters))) - for _, p := range src.Parameters { - if p == nil { - dst = pgio.AppendInt32(dst, -1) - continue - } - - dst = pgio.AppendInt32(dst, int32(len(p))) - dst = append(dst, p...) - } - - dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes))) - for _, fc := range src.ResultFormatCodes { - dst = pgio.AppendInt16(dst, fc) - } - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst -} - -func (src *Bind) MarshalJSON() ([]byte, error) { - formattedParameters := make([]map[string]string, len(src.Parameters)) - for i, p := range src.Parameters { - if p == nil { - continue - } - - if src.ParameterFormatCodes[i] == 0 { - formattedParameters[i] = map[string]string{"text": string(p)} - } else { - formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)} - } - } - - return json.Marshal(struct { - Type string - DestinationPortal string - PreparedStatement string - ParameterFormatCodes []int16 - Parameters []map[string]string - ResultFormatCodes []int16 - }{ - Type: "Bind", - DestinationPortal: src.DestinationPortal, - PreparedStatement: src.PreparedStatement, - ParameterFormatCodes: src.ParameterFormatCodes, - Parameters: formattedParameters, - ResultFormatCodes: src.ResultFormatCodes, - }) -} diff --git a/pgproto3/bind_complete.go b/pgproto3/bind_complete.go deleted file mode 100644 index 60360519..00000000 --- a/pgproto3/bind_complete.go +++ /dev/null @@ -1,29 +0,0 @@ -package pgproto3 - -import ( - "encoding/json" -) - -type BindComplete struct{} - -func (*BindComplete) Backend() {} - -func (dst *BindComplete) Decode(src []byte) error { - if len(src) != 0 { - return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)} - } - - return nil -} - -func (src *BindComplete) Encode(dst []byte) []byte { - return append(dst, '2', 0, 0, 0, 4) -} - -func (src *BindComplete) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - }{ - Type: "BindComplete", - }) -} diff --git a/pgproto3/close.go b/pgproto3/close.go deleted file mode 100644 index 5ff4c886..00000000 --- a/pgproto3/close.go +++ /dev/null @@ -1,59 +0,0 @@ -package pgproto3 - -import ( - "bytes" - "encoding/json" - - "github.com/jackc/pgx/pgio" -) - -type Close struct { - ObjectType byte // 'S' = prepared statement, 'P' = portal - Name string -} - -func (*Close) Frontend() {} - -func (dst *Close) Decode(src []byte) error { - if len(src) < 2 { - return &invalidMessageFormatErr{messageType: "Close"} - } - - dst.ObjectType = src[0] - rp := 1 - - idx := bytes.IndexByte(src[rp:], 0) - if idx != len(src[rp:])-1 { - return &invalidMessageFormatErr{messageType: "Close"} - } - - dst.Name = string(src[rp : len(src)-1]) - - return nil -} - -func (src *Close) Encode(dst []byte) []byte { - dst = append(dst, 'C') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - - dst = append(dst, src.ObjectType) - dst = append(dst, src.Name...) - dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst -} - -func (src *Close) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - ObjectType string - Name string - }{ - Type: "Close", - ObjectType: string(src.ObjectType), - Name: src.Name, - }) -} diff --git a/pgproto3/close_complete.go b/pgproto3/close_complete.go deleted file mode 100644 index db793c94..00000000 --- a/pgproto3/close_complete.go +++ /dev/null @@ -1,29 +0,0 @@ -package pgproto3 - -import ( - "encoding/json" -) - -type CloseComplete struct{} - -func (*CloseComplete) Backend() {} - -func (dst *CloseComplete) Decode(src []byte) error { - if len(src) != 0 { - return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)} - } - - return nil -} - -func (src *CloseComplete) Encode(dst []byte) []byte { - return append(dst, '3', 0, 0, 0, 4) -} - -func (src *CloseComplete) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - }{ - Type: "CloseComplete", - }) -} diff --git a/pgproto3/command_complete.go b/pgproto3/command_complete.go deleted file mode 100644 index 85848532..00000000 --- a/pgproto3/command_complete.go +++ /dev/null @@ -1,48 +0,0 @@ -package pgproto3 - -import ( - "bytes" - "encoding/json" - - "github.com/jackc/pgx/pgio" -) - -type CommandComplete struct { - CommandTag string -} - -func (*CommandComplete) Backend() {} - -func (dst *CommandComplete) Decode(src []byte) error { - idx := bytes.IndexByte(src, 0) - if idx != len(src)-1 { - return &invalidMessageFormatErr{messageType: "CommandComplete"} - } - - dst.CommandTag = string(src[:idx]) - - return nil -} - -func (src *CommandComplete) Encode(dst []byte) []byte { - dst = append(dst, 'C') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - - dst = append(dst, src.CommandTag...) - dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst -} - -func (src *CommandComplete) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - CommandTag string - }{ - Type: "CommandComplete", - CommandTag: src.CommandTag, - }) -} diff --git a/pgproto3/copy_both_response.go b/pgproto3/copy_both_response.go deleted file mode 100644 index 2862a34f..00000000 --- a/pgproto3/copy_both_response.go +++ /dev/null @@ -1,65 +0,0 @@ -package pgproto3 - -import ( - "bytes" - "encoding/binary" - "encoding/json" - - "github.com/jackc/pgx/pgio" -) - -type CopyBothResponse struct { - OverallFormat byte - ColumnFormatCodes []uint16 -} - -func (*CopyBothResponse) Backend() {} - -func (dst *CopyBothResponse) Decode(src []byte) error { - buf := bytes.NewBuffer(src) - - if buf.Len() < 3 { - return &invalidMessageFormatErr{messageType: "CopyBothResponse"} - } - - overallFormat := buf.Next(1)[0] - - columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) - if buf.Len() != columnCount*2 { - return &invalidMessageFormatErr{messageType: "CopyBothResponse"} - } - - columnFormatCodes := make([]uint16, columnCount) - for i := 0; i < columnCount; i++ { - columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) - } - - *dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} - - return nil -} - -func (src *CopyBothResponse) Encode(dst []byte) []byte { - dst = append(dst, 'W') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - - dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) - for _, fc := range src.ColumnFormatCodes { - dst = pgio.AppendUint16(dst, fc) - } - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst -} - -func (src *CopyBothResponse) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - ColumnFormatCodes []uint16 - }{ - Type: "CopyBothResponse", - ColumnFormatCodes: src.ColumnFormatCodes, - }) -} diff --git a/pgproto3/copy_data.go b/pgproto3/copy_data.go deleted file mode 100644 index fab139e6..00000000 --- a/pgproto3/copy_data.go +++ /dev/null @@ -1,37 +0,0 @@ -package pgproto3 - -import ( - "encoding/hex" - "encoding/json" - - "github.com/jackc/pgx/pgio" -) - -type CopyData struct { - Data []byte -} - -func (*CopyData) Backend() {} -func (*CopyData) Frontend() {} - -func (dst *CopyData) Decode(src []byte) error { - dst.Data = src - return nil -} - -func (src *CopyData) Encode(dst []byte) []byte { - dst = append(dst, 'd') - dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) - dst = append(dst, src.Data...) - return dst -} - -func (src *CopyData) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - Data string - }{ - Type: "CopyData", - Data: hex.EncodeToString(src.Data), - }) -} diff --git a/pgproto3/copy_done.go b/pgproto3/copy_done.go deleted file mode 100644 index 92481908..00000000 --- a/pgproto3/copy_done.go +++ /dev/null @@ -1,30 +0,0 @@ -package pgproto3 - -import ( - "encoding/json" -) - -type CopyDone struct { -} - -func (*CopyDone) Backend() {} - -func (dst *CopyDone) Decode(src []byte) error { - if len(src) != 0 { - return &invalidMessageLenErr{messageType: "CopyDone", expectedLen: 0, actualLen: len(src)} - } - - return nil -} - -func (src *CopyDone) Encode(dst []byte) []byte { - return append(dst, 'c', 0, 0, 0, 4) -} - -func (src *CopyDone) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - }{ - Type: "CopyDone", - }) -} diff --git a/pgproto3/copy_fail.go b/pgproto3/copy_fail.go deleted file mode 100644 index 432a311b..00000000 --- a/pgproto3/copy_fail.go +++ /dev/null @@ -1,49 +0,0 @@ -package pgproto3 - -import ( - "bytes" - "encoding/json" - - "github.com/jackc/pgx/pgio" -) - -type CopyFail struct { - Error string -} - -func (*CopyFail) Frontend() {} -func (*CopyFail) Backend() {} - -func (dst *CopyFail) Decode(src []byte) error { - idx := bytes.IndexByte(src, 0) - if idx != len(src)-1 { - return &invalidMessageFormatErr{messageType: "CopyFail"} - } - - dst.Error = string(src[:idx]) - - return nil -} - -func (src *CopyFail) Encode(dst []byte) []byte { - dst = append(dst, 'C') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - - dst = append(dst, src.Error...) - dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst -} - -func (src *CopyFail) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - Error string - }{ - Type: "CopyFail", - Error: src.Error, - }) -} diff --git a/pgproto3/copy_in_response.go b/pgproto3/copy_in_response.go deleted file mode 100644 index 54083cd6..00000000 --- a/pgproto3/copy_in_response.go +++ /dev/null @@ -1,65 +0,0 @@ -package pgproto3 - -import ( - "bytes" - "encoding/binary" - "encoding/json" - - "github.com/jackc/pgx/pgio" -) - -type CopyInResponse struct { - OverallFormat byte - ColumnFormatCodes []uint16 -} - -func (*CopyInResponse) Backend() {} - -func (dst *CopyInResponse) Decode(src []byte) error { - buf := bytes.NewBuffer(src) - - if buf.Len() < 3 { - return &invalidMessageFormatErr{messageType: "CopyInResponse"} - } - - overallFormat := buf.Next(1)[0] - - columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) - if buf.Len() != columnCount*2 { - return &invalidMessageFormatErr{messageType: "CopyInResponse"} - } - - columnFormatCodes := make([]uint16, columnCount) - for i := 0; i < columnCount; i++ { - columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) - } - - *dst = CopyInResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} - - return nil -} - -func (src *CopyInResponse) Encode(dst []byte) []byte { - dst = append(dst, 'G') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - - dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) - for _, fc := range src.ColumnFormatCodes { - dst = pgio.AppendUint16(dst, fc) - } - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst -} - -func (src *CopyInResponse) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - ColumnFormatCodes []uint16 - }{ - Type: "CopyInResponse", - ColumnFormatCodes: src.ColumnFormatCodes, - }) -} diff --git a/pgproto3/copy_out_response.go b/pgproto3/copy_out_response.go deleted file mode 100644 index eaa33b8b..00000000 --- a/pgproto3/copy_out_response.go +++ /dev/null @@ -1,65 +0,0 @@ -package pgproto3 - -import ( - "bytes" - "encoding/binary" - "encoding/json" - - "github.com/jackc/pgx/pgio" -) - -type CopyOutResponse struct { - OverallFormat byte - ColumnFormatCodes []uint16 -} - -func (*CopyOutResponse) Backend() {} - -func (dst *CopyOutResponse) Decode(src []byte) error { - buf := bytes.NewBuffer(src) - - if buf.Len() < 3 { - return &invalidMessageFormatErr{messageType: "CopyOutResponse"} - } - - overallFormat := buf.Next(1)[0] - - columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) - if buf.Len() != columnCount*2 { - return &invalidMessageFormatErr{messageType: "CopyOutResponse"} - } - - columnFormatCodes := make([]uint16, columnCount) - for i := 0; i < columnCount; i++ { - columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) - } - - *dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} - - return nil -} - -func (src *CopyOutResponse) Encode(dst []byte) []byte { - dst = append(dst, 'H') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - - dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) - for _, fc := range src.ColumnFormatCodes { - dst = pgio.AppendUint16(dst, fc) - } - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst -} - -func (src *CopyOutResponse) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - ColumnFormatCodes []uint16 - }{ - Type: "CopyOutResponse", - ColumnFormatCodes: src.ColumnFormatCodes, - }) -} diff --git a/pgproto3/data_row.go b/pgproto3/data_row.go deleted file mode 100644 index e46d3cc0..00000000 --- a/pgproto3/data_row.go +++ /dev/null @@ -1,112 +0,0 @@ -package pgproto3 - -import ( - "encoding/binary" - "encoding/hex" - "encoding/json" - - "github.com/jackc/pgx/pgio" -) - -type DataRow struct { - Values [][]byte -} - -func (*DataRow) Backend() {} - -func (dst *DataRow) Decode(src []byte) error { - if len(src) < 2 { - return &invalidMessageFormatErr{messageType: "DataRow"} - } - rp := 0 - fieldCount := int(binary.BigEndian.Uint16(src[rp:])) - rp += 2 - - // If the capacity of the values slice is too small OR substantially too - // large reallocate. This is too avoid one row with many columns from - // permanently allocating memory. - if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 { - newCap := 32 - if newCap < fieldCount { - newCap = fieldCount - } - dst.Values = make([][]byte, fieldCount, newCap) - } else { - dst.Values = dst.Values[:fieldCount] - } - - for i := 0; i < fieldCount; i++ { - if len(src[rp:]) < 4 { - return &invalidMessageFormatErr{messageType: "DataRow"} - } - - msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - - // null - if msgSize == -1 { - dst.Values[i] = nil - } else { - if len(src[rp:]) < msgSize { - return &invalidMessageFormatErr{messageType: "DataRow"} - } - - dst.Values[i] = src[rp : rp+msgSize] - rp += msgSize - } - } - - return nil -} - -func (src *DataRow) Encode(dst []byte) []byte { - dst = append(dst, 'D') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - - dst = pgio.AppendUint16(dst, uint16(len(src.Values))) - for _, v := range src.Values { - if v == nil { - dst = pgio.AppendInt32(dst, -1) - continue - } - - dst = pgio.AppendInt32(dst, int32(len(v))) - dst = append(dst, v...) - } - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst -} - -func (src *DataRow) MarshalJSON() ([]byte, error) { - formattedValues := make([]map[string]string, len(src.Values)) - for i, v := range src.Values { - if v == nil { - continue - } - - var hasNonPrintable bool - for _, b := range v { - if b < 32 { - hasNonPrintable = true - break - } - } - - if hasNonPrintable { - formattedValues[i] = map[string]string{"binary": hex.EncodeToString(v)} - } else { - formattedValues[i] = map[string]string{"text": string(v)} - } - } - - return json.Marshal(struct { - Type string - Values []map[string]string - }{ - Type: "DataRow", - Values: formattedValues, - }) -} diff --git a/pgproto3/describe.go b/pgproto3/describe.go deleted file mode 100644 index bb7bc056..00000000 --- a/pgproto3/describe.go +++ /dev/null @@ -1,59 +0,0 @@ -package pgproto3 - -import ( - "bytes" - "encoding/json" - - "github.com/jackc/pgx/pgio" -) - -type Describe struct { - ObjectType byte // 'S' = prepared statement, 'P' = portal - Name string -} - -func (*Describe) Frontend() {} - -func (dst *Describe) Decode(src []byte) error { - if len(src) < 2 { - return &invalidMessageFormatErr{messageType: "Describe"} - } - - dst.ObjectType = src[0] - rp := 1 - - idx := bytes.IndexByte(src[rp:], 0) - if idx != len(src[rp:])-1 { - return &invalidMessageFormatErr{messageType: "Describe"} - } - - dst.Name = string(src[rp : len(src)-1]) - - return nil -} - -func (src *Describe) Encode(dst []byte) []byte { - dst = append(dst, 'D') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - - dst = append(dst, src.ObjectType) - dst = append(dst, src.Name...) - dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst -} - -func (src *Describe) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - ObjectType string - Name string - }{ - Type: "Describe", - ObjectType: string(src.ObjectType), - Name: src.Name, - }) -} diff --git a/pgproto3/empty_query_response.go b/pgproto3/empty_query_response.go deleted file mode 100644 index d283b06d..00000000 --- a/pgproto3/empty_query_response.go +++ /dev/null @@ -1,29 +0,0 @@ -package pgproto3 - -import ( - "encoding/json" -) - -type EmptyQueryResponse struct{} - -func (*EmptyQueryResponse) Backend() {} - -func (dst *EmptyQueryResponse) Decode(src []byte) error { - if len(src) != 0 { - return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)} - } - - return nil -} - -func (src *EmptyQueryResponse) Encode(dst []byte) []byte { - return append(dst, 'I', 0, 0, 0, 4) -} - -func (src *EmptyQueryResponse) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - }{ - Type: "EmptyQueryResponse", - }) -} diff --git a/pgproto3/error_response.go b/pgproto3/error_response.go deleted file mode 100644 index 987fe38a..00000000 --- a/pgproto3/error_response.go +++ /dev/null @@ -1,214 +0,0 @@ -package pgproto3 - -import ( - "bytes" - "encoding/binary" - "strconv" -) - -type ErrorResponse struct { - Severity string - Code string - Message string - Detail string - Hint string - Position int32 - InternalPosition int32 - InternalQuery string - Where string - SchemaName string - TableName string - ColumnName string - DataTypeName string - ConstraintName string - File string - Line int32 - Routine string - - UnknownFields map[byte]string -} - -func (*ErrorResponse) Backend() {} - -func (dst *ErrorResponse) Decode(src []byte) error { - *dst = ErrorResponse{} - - buf := bytes.NewBuffer(src) - - for { - k, err := buf.ReadByte() - if err != nil { - return err - } - if k == 0 { - break - } - - vb, err := buf.ReadBytes(0) - if err != nil { - return err - } - v := string(vb[:len(vb)-1]) - - switch k { - case 'S': - dst.Severity = v - case 'C': - dst.Code = v - case 'M': - dst.Message = v - case 'D': - dst.Detail = v - case 'H': - dst.Hint = v - case 'P': - s := v - n, _ := strconv.ParseInt(s, 10, 32) - dst.Position = int32(n) - case 'p': - s := v - n, _ := strconv.ParseInt(s, 10, 32) - dst.InternalPosition = int32(n) - case 'q': - dst.InternalQuery = v - case 'W': - dst.Where = v - case 's': - dst.SchemaName = v - case 't': - dst.TableName = v - case 'c': - dst.ColumnName = v - case 'd': - dst.DataTypeName = v - case 'n': - dst.ConstraintName = v - case 'F': - dst.File = v - case 'L': - s := v - n, _ := strconv.ParseInt(s, 10, 32) - dst.Line = int32(n) - case 'R': - dst.Routine = v - - default: - if dst.UnknownFields == nil { - dst.UnknownFields = make(map[byte]string) - } - dst.UnknownFields[k] = v - } - } - - return nil -} - -func (src *ErrorResponse) Encode(dst []byte) []byte { - return append(dst, src.marshalBinary('E')...) -} - -func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte(typeByte) - buf.Write(bigEndian.Uint32(0)) - - if src.Severity != "" { - buf.WriteByte('S') - buf.WriteString(src.Severity) - buf.WriteByte(0) - } - if src.Code != "" { - buf.WriteByte('C') - buf.WriteString(src.Code) - buf.WriteByte(0) - } - if src.Message != "" { - buf.WriteByte('M') - buf.WriteString(src.Message) - buf.WriteByte(0) - } - if src.Detail != "" { - buf.WriteByte('D') - buf.WriteString(src.Detail) - buf.WriteByte(0) - } - if src.Hint != "" { - buf.WriteByte('H') - buf.WriteString(src.Hint) - buf.WriteByte(0) - } - if src.Position != 0 { - buf.WriteByte('P') - buf.WriteString(strconv.Itoa(int(src.Position))) - buf.WriteByte(0) - } - if src.InternalPosition != 0 { - buf.WriteByte('p') - buf.WriteString(strconv.Itoa(int(src.InternalPosition))) - buf.WriteByte(0) - } - if src.InternalQuery != "" { - buf.WriteByte('q') - buf.WriteString(src.InternalQuery) - buf.WriteByte(0) - } - if src.Where != "" { - buf.WriteByte('W') - buf.WriteString(src.Where) - buf.WriteByte(0) - } - if src.SchemaName != "" { - buf.WriteByte('s') - buf.WriteString(src.SchemaName) - buf.WriteByte(0) - } - if src.TableName != "" { - buf.WriteByte('t') - buf.WriteString(src.TableName) - buf.WriteByte(0) - } - if src.ColumnName != "" { - buf.WriteByte('c') - buf.WriteString(src.ColumnName) - buf.WriteByte(0) - } - if src.DataTypeName != "" { - buf.WriteByte('d') - buf.WriteString(src.DataTypeName) - buf.WriteByte(0) - } - if src.ConstraintName != "" { - buf.WriteByte('n') - buf.WriteString(src.ConstraintName) - buf.WriteByte(0) - } - if src.File != "" { - buf.WriteByte('F') - buf.WriteString(src.File) - buf.WriteByte(0) - } - if src.Line != 0 { - buf.WriteByte('L') - buf.WriteString(strconv.Itoa(int(src.Line))) - buf.WriteByte(0) - } - if src.Routine != "" { - buf.WriteByte('R') - buf.WriteString(src.Routine) - buf.WriteByte(0) - } - - for k, v := range src.UnknownFields { - buf.WriteByte(k) - buf.WriteByte(0) - buf.WriteString(v) - buf.WriteByte(0) - } - buf.WriteByte(0) - - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) - - return buf.Bytes() -} diff --git a/pgproto3/execute.go b/pgproto3/execute.go deleted file mode 100644 index 76da9943..00000000 --- a/pgproto3/execute.go +++ /dev/null @@ -1,60 +0,0 @@ -package pgproto3 - -import ( - "bytes" - "encoding/binary" - "encoding/json" - - "github.com/jackc/pgx/pgio" -) - -type Execute struct { - Portal string - MaxRows uint32 -} - -func (*Execute) Frontend() {} - -func (dst *Execute) Decode(src []byte) error { - buf := bytes.NewBuffer(src) - - b, err := buf.ReadBytes(0) - if err != nil { - return err - } - dst.Portal = string(b[:len(b)-1]) - - if buf.Len() < 4 { - return &invalidMessageFormatErr{messageType: "Execute"} - } - dst.MaxRows = binary.BigEndian.Uint32(buf.Next(4)) - - return nil -} - -func (src *Execute) Encode(dst []byte) []byte { - dst = append(dst, 'E') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - - dst = append(dst, src.Portal...) - dst = append(dst, 0) - - dst = pgio.AppendUint32(dst, src.MaxRows) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst -} - -func (src *Execute) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - Portal string - MaxRows uint32 - }{ - Type: "Execute", - Portal: src.Portal, - MaxRows: src.MaxRows, - }) -} diff --git a/pgproto3/flush.go b/pgproto3/flush.go deleted file mode 100644 index 7fd5e987..00000000 --- a/pgproto3/flush.go +++ /dev/null @@ -1,29 +0,0 @@ -package pgproto3 - -import ( - "encoding/json" -) - -type Flush struct{} - -func (*Flush) Frontend() {} - -func (dst *Flush) Decode(src []byte) error { - if len(src) != 0 { - return &invalidMessageLenErr{messageType: "Flush", expectedLen: 0, actualLen: len(src)} - } - - return nil -} - -func (src *Flush) Encode(dst []byte) []byte { - return append(dst, 'H', 0, 0, 0, 4) -} - -func (src *Flush) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - }{ - Type: "Flush", - }) -} diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go deleted file mode 100644 index 31a955bc..00000000 --- a/pgproto3/frontend.go +++ /dev/null @@ -1,128 +0,0 @@ -package pgproto3 - -import ( - "encoding/binary" - "io" - - "github.com/jackc/pgx/chunkreader" - "github.com/pkg/errors" -) - -type Frontend struct { - cr *chunkreader.ChunkReader - w io.Writer - - // Backend message flyweights - authentication Authentication - backendKeyData BackendKeyData - bindComplete BindComplete - closeComplete CloseComplete - commandComplete CommandComplete - copyBothResponse CopyBothResponse - copyData CopyData - copyInResponse CopyInResponse - copyOutResponse CopyOutResponse - copyDone CopyDone - copyFail CopyFail - dataRow DataRow - emptyQueryResponse EmptyQueryResponse - errorResponse ErrorResponse - functionCallResponse FunctionCallResponse - noData NoData - noticeResponse NoticeResponse - notificationResponse NotificationResponse - parameterDescription ParameterDescription - parameterStatus ParameterStatus - parseComplete ParseComplete - readyForQuery ReadyForQuery - rowDescription RowDescription - - bodyLen int - msgType byte - partialMsg bool -} - -func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) { - cr := chunkreader.NewChunkReader(r) - return &Frontend{cr: cr, w: w}, nil -} - -func (b *Frontend) Send(msg FrontendMessage) error { - _, err := b.w.Write(msg.Encode(nil)) - return err -} - -func (b *Frontend) Receive() (BackendMessage, error) { - if !b.partialMsg { - header, err := b.cr.Next(5) - if err != nil { - return nil, err - } - - b.msgType = header[0] - b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 - b.partialMsg = true - } - - var msg BackendMessage - switch b.msgType { - case '1': - msg = &b.parseComplete - case '2': - msg = &b.bindComplete - case '3': - msg = &b.closeComplete - case 'A': - msg = &b.notificationResponse - case 'c': - msg = &b.copyDone - case 'C': - msg = &b.commandComplete - case 'd': - msg = &b.copyData - case 'D': - msg = &b.dataRow - case 'E': - msg = &b.errorResponse - case 'f': - msg = &b.copyFail - case 'G': - msg = &b.copyInResponse - case 'H': - msg = &b.copyOutResponse - case 'I': - msg = &b.emptyQueryResponse - case 'K': - msg = &b.backendKeyData - case 'n': - msg = &b.noData - case 'N': - msg = &b.noticeResponse - case 'R': - msg = &b.authentication - case 'S': - msg = &b.parameterStatus - case 't': - msg = &b.parameterDescription - case 'T': - msg = &b.rowDescription - case 'V': - msg = &b.functionCallResponse - case 'W': - msg = &b.copyBothResponse - case 'Z': - msg = &b.readyForQuery - default: - return nil, errors.Errorf("unknown message type: %c", b.msgType) - } - - msgBody, err := b.cr.Next(b.bodyLen) - if err != nil { - return nil, err - } - - b.partialMsg = false - - err = msg.Decode(msgBody) - return msg, err -} diff --git a/pgproto3/frontend_test.go b/pgproto3/frontend_test.go deleted file mode 100644 index 7d6652c1..00000000 --- a/pgproto3/frontend_test.go +++ /dev/null @@ -1,62 +0,0 @@ -package pgproto3_test - -import ( - "testing" - - "github.com/pkg/errors" - - "github.com/jackc/pgx/pgproto3" -) - -type interruptReader struct { - chunks [][]byte -} - -func (ir *interruptReader) Read(p []byte) (n int, err error) { - if len(ir.chunks) == 0 { - return 0, errors.New("no data") - } - - n = copy(p, ir.chunks[0]) - if n != len(ir.chunks[0]) { - panic("this test reader doesn't support partial reads of chunks") - } - - ir.chunks = ir.chunks[1:] - - return n, nil -} - -func (ir *interruptReader) push(p []byte) { - ir.chunks = append(ir.chunks, p) -} - -func TestFrontendReceiveInterrupted(t *testing.T) { - t.Parallel() - - server := &interruptReader{} - server.push([]byte{'Z', 0, 0, 0, 5}) - - frontend, err := pgproto3.NewFrontend(server, nil) - if err != nil { - t.Fatal(err) - } - - msg, err := frontend.Receive() - if err == nil { - t.Fatal("expected err") - } - if msg != nil { - t.Fatalf("did not expect msg, but %v", msg) - } - - server.push([]byte{'I'}) - - msg, err = frontend.Receive() - if err != nil { - t.Fatal(err) - } - if msg, ok := msg.(*pgproto3.ReadyForQuery); !ok || msg.TxStatus != 'I' { - t.Fatalf("unexpected msg: %v", msg) - } -} diff --git a/pgproto3/function_call_response.go b/pgproto3/function_call_response.go deleted file mode 100644 index bb325b69..00000000 --- a/pgproto3/function_call_response.go +++ /dev/null @@ -1,78 +0,0 @@ -package pgproto3 - -import ( - "encoding/binary" - "encoding/hex" - "encoding/json" - - "github.com/jackc/pgx/pgio" -) - -type FunctionCallResponse struct { - Result []byte -} - -func (*FunctionCallResponse) Backend() {} - -func (dst *FunctionCallResponse) Decode(src []byte) error { - if len(src) < 4 { - return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} - } - rp := 0 - resultSize := int(binary.BigEndian.Uint32(src[rp:])) - rp += 4 - - if resultSize == -1 { - dst.Result = nil - return nil - } - - if len(src[rp:]) != resultSize { - return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} - } - - dst.Result = src[rp:] - return nil -} - -func (src *FunctionCallResponse) Encode(dst []byte) []byte { - dst = append(dst, 'V') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - - if src.Result == nil { - dst = pgio.AppendInt32(dst, -1) - } else { - dst = pgio.AppendInt32(dst, int32(len(src.Result))) - dst = append(dst, src.Result...) - } - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst -} - -func (src *FunctionCallResponse) MarshalJSON() ([]byte, error) { - var formattedValue map[string]string - var hasNonPrintable bool - for _, b := range src.Result { - if b < 32 { - hasNonPrintable = true - break - } - } - - if hasNonPrintable { - formattedValue = map[string]string{"binary": hex.EncodeToString(src.Result)} - } else { - formattedValue = map[string]string{"text": string(src.Result)} - } - - return json.Marshal(struct { - Type string - Result map[string]string - }{ - Type: "FunctionCallResponse", - Result: formattedValue, - }) -} diff --git a/pgproto3/no_data.go b/pgproto3/no_data.go deleted file mode 100644 index 1fb47c2a..00000000 --- a/pgproto3/no_data.go +++ /dev/null @@ -1,29 +0,0 @@ -package pgproto3 - -import ( - "encoding/json" -) - -type NoData struct{} - -func (*NoData) Backend() {} - -func (dst *NoData) Decode(src []byte) error { - if len(src) != 0 { - return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)} - } - - return nil -} - -func (src *NoData) Encode(dst []byte) []byte { - return append(dst, 'n', 0, 0, 0, 4) -} - -func (src *NoData) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - }{ - Type: "NoData", - }) -} diff --git a/pgproto3/notice_response.go b/pgproto3/notice_response.go deleted file mode 100644 index e4595aa5..00000000 --- a/pgproto3/notice_response.go +++ /dev/null @@ -1,13 +0,0 @@ -package pgproto3 - -type NoticeResponse ErrorResponse - -func (*NoticeResponse) Backend() {} - -func (dst *NoticeResponse) Decode(src []byte) error { - return (*ErrorResponse)(dst).Decode(src) -} - -func (src *NoticeResponse) Encode(dst []byte) []byte { - return append(dst, (*ErrorResponse)(src).marshalBinary('N')...) -} diff --git a/pgproto3/notification_response.go b/pgproto3/notification_response.go deleted file mode 100644 index b14007b4..00000000 --- a/pgproto3/notification_response.go +++ /dev/null @@ -1,67 +0,0 @@ -package pgproto3 - -import ( - "bytes" - "encoding/binary" - "encoding/json" - - "github.com/jackc/pgx/pgio" -) - -type NotificationResponse struct { - PID uint32 - Channel string - Payload string -} - -func (*NotificationResponse) Backend() {} - -func (dst *NotificationResponse) Decode(src []byte) error { - buf := bytes.NewBuffer(src) - - pid := binary.BigEndian.Uint32(buf.Next(4)) - - b, err := buf.ReadBytes(0) - if err != nil { - return err - } - channel := string(b[:len(b)-1]) - - b, err = buf.ReadBytes(0) - if err != nil { - return err - } - payload := string(b[:len(b)-1]) - - *dst = NotificationResponse{PID: pid, Channel: channel, Payload: payload} - return nil -} - -func (src *NotificationResponse) Encode(dst []byte) []byte { - dst = append(dst, 'A') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - - dst = append(dst, src.Channel...) - dst = append(dst, 0) - dst = append(dst, src.Payload...) - dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst -} - -func (src *NotificationResponse) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - PID uint32 - Channel string - Payload string - }{ - Type: "NotificationResponse", - PID: src.PID, - Channel: src.Channel, - Payload: src.Payload, - }) -} diff --git a/pgproto3/parameter_description.go b/pgproto3/parameter_description.go deleted file mode 100644 index 1fa3c927..00000000 --- a/pgproto3/parameter_description.go +++ /dev/null @@ -1,61 +0,0 @@ -package pgproto3 - -import ( - "bytes" - "encoding/binary" - "encoding/json" - - "github.com/jackc/pgx/pgio" -) - -type ParameterDescription struct { - ParameterOIDs []uint32 -} - -func (*ParameterDescription) Backend() {} - -func (dst *ParameterDescription) Decode(src []byte) error { - buf := bytes.NewBuffer(src) - - if buf.Len() < 2 { - return &invalidMessageFormatErr{messageType: "ParameterDescription"} - } - - // Reported parameter count will be incorrect when number of args is greater than uint16 - buf.Next(2) - // Instead infer parameter count by remaining size of message - parameterCount := buf.Len() / 4 - - *dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)} - - for i := 0; i < parameterCount; i++ { - dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4)) - } - - return nil -} - -func (src *ParameterDescription) Encode(dst []byte) []byte { - dst = append(dst, 't') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - - dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) - for _, oid := range src.ParameterOIDs { - dst = pgio.AppendUint32(dst, oid) - } - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst -} - -func (src *ParameterDescription) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - ParameterOIDs []uint32 - }{ - Type: "ParameterDescription", - ParameterOIDs: src.ParameterOIDs, - }) -} diff --git a/pgproto3/parameter_status.go b/pgproto3/parameter_status.go deleted file mode 100644 index b3bac33f..00000000 --- a/pgproto3/parameter_status.go +++ /dev/null @@ -1,61 +0,0 @@ -package pgproto3 - -import ( - "bytes" - "encoding/json" - - "github.com/jackc/pgx/pgio" -) - -type ParameterStatus struct { - Name string - Value string -} - -func (*ParameterStatus) Backend() {} - -func (dst *ParameterStatus) Decode(src []byte) error { - buf := bytes.NewBuffer(src) - - b, err := buf.ReadBytes(0) - if err != nil { - return err - } - name := string(b[:len(b)-1]) - - b, err = buf.ReadBytes(0) - if err != nil { - return err - } - value := string(b[:len(b)-1]) - - *dst = ParameterStatus{Name: name, Value: value} - return nil -} - -func (src *ParameterStatus) Encode(dst []byte) []byte { - dst = append(dst, 'S') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - - dst = append(dst, src.Name...) - dst = append(dst, 0) - dst = append(dst, src.Value...) - dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst -} - -func (ps *ParameterStatus) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - Name string - Value string - }{ - Type: "ParameterStatus", - Name: ps.Name, - Value: ps.Value, - }) -} diff --git a/pgproto3/parse.go b/pgproto3/parse.go deleted file mode 100644 index ca4834c6..00000000 --- a/pgproto3/parse.go +++ /dev/null @@ -1,83 +0,0 @@ -package pgproto3 - -import ( - "bytes" - "encoding/binary" - "encoding/json" - - "github.com/jackc/pgx/pgio" -) - -type Parse struct { - Name string - Query string - ParameterOIDs []uint32 -} - -func (*Parse) Frontend() {} - -func (dst *Parse) Decode(src []byte) error { - *dst = Parse{} - - buf := bytes.NewBuffer(src) - - b, err := buf.ReadBytes(0) - if err != nil { - return err - } - dst.Name = string(b[:len(b)-1]) - - b, err = buf.ReadBytes(0) - if err != nil { - return err - } - dst.Query = string(b[:len(b)-1]) - - if buf.Len() < 2 { - return &invalidMessageFormatErr{messageType: "Parse"} - } - parameterOIDCount := int(binary.BigEndian.Uint16(buf.Next(2))) - - for i := 0; i < parameterOIDCount; i++ { - if buf.Len() < 4 { - return &invalidMessageFormatErr{messageType: "Parse"} - } - dst.ParameterOIDs = append(dst.ParameterOIDs, binary.BigEndian.Uint32(buf.Next(4))) - } - - return nil -} - -func (src *Parse) Encode(dst []byte) []byte { - dst = append(dst, 'P') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - - dst = append(dst, src.Name...) - dst = append(dst, 0) - dst = append(dst, src.Query...) - dst = append(dst, 0) - - dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) - for _, oid := range src.ParameterOIDs { - dst = pgio.AppendUint32(dst, oid) - } - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst -} - -func (src *Parse) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - Name string - Query string - ParameterOIDs []uint32 - }{ - Type: "Parse", - Name: src.Name, - Query: src.Query, - ParameterOIDs: src.ParameterOIDs, - }) -} diff --git a/pgproto3/parse_complete.go b/pgproto3/parse_complete.go deleted file mode 100644 index 462a89ba..00000000 --- a/pgproto3/parse_complete.go +++ /dev/null @@ -1,29 +0,0 @@ -package pgproto3 - -import ( - "encoding/json" -) - -type ParseComplete struct{} - -func (*ParseComplete) Backend() {} - -func (dst *ParseComplete) Decode(src []byte) error { - if len(src) != 0 { - return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)} - } - - return nil -} - -func (src *ParseComplete) Encode(dst []byte) []byte { - return append(dst, '1', 0, 0, 0, 4) -} - -func (src *ParseComplete) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - }{ - Type: "ParseComplete", - }) -} diff --git a/pgproto3/password_message.go b/pgproto3/password_message.go deleted file mode 100644 index 2ad3fe4a..00000000 --- a/pgproto3/password_message.go +++ /dev/null @@ -1,46 +0,0 @@ -package pgproto3 - -import ( - "bytes" - "encoding/json" - - "github.com/jackc/pgx/pgio" -) - -type PasswordMessage struct { - Password string -} - -func (*PasswordMessage) Frontend() {} - -func (dst *PasswordMessage) Decode(src []byte) error { - buf := bytes.NewBuffer(src) - - b, err := buf.ReadBytes(0) - if err != nil { - return err - } - dst.Password = string(b[:len(b)-1]) - - return nil -} - -func (src *PasswordMessage) Encode(dst []byte) []byte { - dst = append(dst, 'p') - dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1)) - - dst = append(dst, src.Password...) - dst = append(dst, 0) - - return dst -} - -func (src *PasswordMessage) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - Password string - }{ - Type: "PasswordMessage", - Password: src.Password, - }) -} diff --git a/pgproto3/pgproto3.go b/pgproto3/pgproto3.go deleted file mode 100644 index fe7b085b..00000000 --- a/pgproto3/pgproto3.go +++ /dev/null @@ -1,42 +0,0 @@ -package pgproto3 - -import "fmt" - -// Message is the interface implemented by an object that can decode and encode -// a particular PostgreSQL message. -type Message interface { - // Decode is allowed and expected to retain a reference to data after - // returning (unlike encoding.BinaryUnmarshaler). - Decode(data []byte) error - - // Encode appends itself to dst and returns the new buffer. - Encode(dst []byte) []byte -} - -type FrontendMessage interface { - Message - Frontend() // no-op method to distinguish frontend from backend methods -} - -type BackendMessage interface { - Message - Backend() // no-op method to distinguish frontend from backend methods -} - -type invalidMessageLenErr struct { - messageType string - expectedLen int - actualLen int -} - -func (e *invalidMessageLenErr) Error() string { - return fmt.Sprintf("%s body must have length of %d, but it is %d", e.messageType, e.expectedLen, e.actualLen) -} - -type invalidMessageFormatErr struct { - messageType string -} - -func (e *invalidMessageFormatErr) Error() string { - return fmt.Sprintf("%s body is invalid", e.messageType) -} diff --git a/pgproto3/query.go b/pgproto3/query.go deleted file mode 100644 index d80c0fb4..00000000 --- a/pgproto3/query.go +++ /dev/null @@ -1,45 +0,0 @@ -package pgproto3 - -import ( - "bytes" - "encoding/json" - - "github.com/jackc/pgx/pgio" -) - -type Query struct { - String string -} - -func (*Query) Frontend() {} - -func (dst *Query) Decode(src []byte) error { - i := bytes.IndexByte(src, 0) - if i != len(src)-1 { - return &invalidMessageFormatErr{messageType: "Query"} - } - - dst.String = string(src[:i]) - - return nil -} - -func (src *Query) Encode(dst []byte) []byte { - dst = append(dst, 'Q') - dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1)) - - dst = append(dst, src.String...) - dst = append(dst, 0) - - return dst -} - -func (src *Query) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - String string - }{ - Type: "Query", - String: src.String, - }) -} diff --git a/pgproto3/ready_for_query.go b/pgproto3/ready_for_query.go deleted file mode 100644 index 63b902bd..00000000 --- a/pgproto3/ready_for_query.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgproto3 - -import ( - "encoding/json" -) - -type ReadyForQuery struct { - TxStatus byte -} - -func (*ReadyForQuery) Backend() {} - -func (dst *ReadyForQuery) Decode(src []byte) error { - if len(src) != 1 { - return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)} - } - - dst.TxStatus = src[0] - - return nil -} - -func (src *ReadyForQuery) Encode(dst []byte) []byte { - return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus) -} - -func (src *ReadyForQuery) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - TxStatus string - }{ - Type: "ReadyForQuery", - TxStatus: string(src.TxStatus), - }) -} diff --git a/pgproto3/row_description.go b/pgproto3/row_description.go deleted file mode 100644 index 7deba379..00000000 --- a/pgproto3/row_description.go +++ /dev/null @@ -1,100 +0,0 @@ -package pgproto3 - -import ( - "bytes" - "encoding/binary" - "encoding/json" - - "github.com/jackc/pgx/pgio" -) - -const ( - TextFormat = 0 - BinaryFormat = 1 -) - -type FieldDescription struct { - Name string - TableOID uint32 - TableAttributeNumber uint16 - DataTypeOID uint32 - DataTypeSize int16 - TypeModifier int32 - Format int16 -} - -type RowDescription struct { - Fields []FieldDescription -} - -func (*RowDescription) Backend() {} - -func (dst *RowDescription) Decode(src []byte) error { - buf := bytes.NewBuffer(src) - - if buf.Len() < 2 { - return &invalidMessageFormatErr{messageType: "RowDescription"} - } - fieldCount := int(binary.BigEndian.Uint16(buf.Next(2))) - - dst.Fields = dst.Fields[0:0] - - for i := 0; i < fieldCount; i++ { - var fd FieldDescription - bName, err := buf.ReadBytes(0) - if err != nil { - return err - } - fd.Name = string(bName[:len(bName)-1]) - - // Since buf.Next() doesn't return an error if we hit the end of the buffer - // check Len ahead of time - if buf.Len() < 18 { - return &invalidMessageFormatErr{messageType: "RowDescription"} - } - - fd.TableOID = binary.BigEndian.Uint32(buf.Next(4)) - fd.TableAttributeNumber = binary.BigEndian.Uint16(buf.Next(2)) - fd.DataTypeOID = binary.BigEndian.Uint32(buf.Next(4)) - fd.DataTypeSize = int16(binary.BigEndian.Uint16(buf.Next(2))) - fd.TypeModifier = int32(binary.BigEndian.Uint32(buf.Next(4))) - fd.Format = int16(binary.BigEndian.Uint16(buf.Next(2))) - - dst.Fields = append(dst.Fields, fd) - } - - return nil -} - -func (src *RowDescription) Encode(dst []byte) []byte { - dst = append(dst, 'T') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - - dst = pgio.AppendUint16(dst, uint16(len(src.Fields))) - for _, fd := range src.Fields { - dst = append(dst, fd.Name...) - dst = append(dst, 0) - - dst = pgio.AppendUint32(dst, fd.TableOID) - dst = pgio.AppendUint16(dst, fd.TableAttributeNumber) - dst = pgio.AppendUint32(dst, fd.DataTypeOID) - dst = pgio.AppendInt16(dst, fd.DataTypeSize) - dst = pgio.AppendInt32(dst, fd.TypeModifier) - dst = pgio.AppendInt16(dst, fd.Format) - } - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst -} - -func (src *RowDescription) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - Fields []FieldDescription - }{ - Type: "RowDescription", - Fields: src.Fields, - }) -} diff --git a/pgproto3/startup_message.go b/pgproto3/startup_message.go deleted file mode 100644 index 6c5d4f99..00000000 --- a/pgproto3/startup_message.go +++ /dev/null @@ -1,97 +0,0 @@ -package pgproto3 - -import ( - "bytes" - "encoding/binary" - "encoding/json" - - "github.com/jackc/pgx/pgio" - "github.com/pkg/errors" -) - -const ( - ProtocolVersionNumber = 196608 // 3.0 - sslRequestNumber = 80877103 -) - -type StartupMessage struct { - ProtocolVersion uint32 - Parameters map[string]string -} - -func (*StartupMessage) Frontend() {} - -func (dst *StartupMessage) Decode(src []byte) error { - if len(src) < 4 { - return errors.Errorf("startup message too short") - } - - dst.ProtocolVersion = binary.BigEndian.Uint32(src) - rp := 4 - - if dst.ProtocolVersion == sslRequestNumber { - return errors.Errorf("can't handle ssl connection request") - } - - if dst.ProtocolVersion != ProtocolVersionNumber { - return errors.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) - } - - dst.Parameters = make(map[string]string) - for { - idx := bytes.IndexByte(src[rp:], 0) - if idx < 0 { - return &invalidMessageFormatErr{messageType: "StartupMesage"} - } - key := string(src[rp : rp+idx]) - rp += idx + 1 - - idx = bytes.IndexByte(src[rp:], 0) - if idx < 0 { - return &invalidMessageFormatErr{messageType: "StartupMesage"} - } - value := string(src[rp : rp+idx]) - rp += idx + 1 - - dst.Parameters[key] = value - - if len(src[rp:]) == 1 { - if src[rp] != 0 { - return errors.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp]) - } - break - } - } - - return nil -} - -func (src *StartupMessage) Encode(dst []byte) []byte { - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - - dst = pgio.AppendUint32(dst, src.ProtocolVersion) - for k, v := range src.Parameters { - dst = append(dst, k...) - dst = append(dst, 0) - dst = append(dst, v...) - dst = append(dst, 0) - } - dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst -} - -func (src *StartupMessage) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - ProtocolVersion uint32 - Parameters map[string]string - }{ - Type: "StartupMessage", - ProtocolVersion: src.ProtocolVersion, - Parameters: src.Parameters, - }) -} diff --git a/pgproto3/sync.go b/pgproto3/sync.go deleted file mode 100644 index 85f4749a..00000000 --- a/pgproto3/sync.go +++ /dev/null @@ -1,29 +0,0 @@ -package pgproto3 - -import ( - "encoding/json" -) - -type Sync struct{} - -func (*Sync) Frontend() {} - -func (dst *Sync) Decode(src []byte) error { - if len(src) != 0 { - return &invalidMessageLenErr{messageType: "Sync", expectedLen: 0, actualLen: len(src)} - } - - return nil -} - -func (src *Sync) Encode(dst []byte) []byte { - return append(dst, 'S', 0, 0, 0, 4) -} - -func (src *Sync) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - }{ - Type: "Sync", - }) -} diff --git a/pgproto3/terminate.go b/pgproto3/terminate.go deleted file mode 100644 index 0a3310da..00000000 --- a/pgproto3/terminate.go +++ /dev/null @@ -1,29 +0,0 @@ -package pgproto3 - -import ( - "encoding/json" -) - -type Terminate struct{} - -func (*Terminate) Frontend() {} - -func (dst *Terminate) Decode(src []byte) error { - if len(src) != 0 { - return &invalidMessageLenErr{messageType: "Terminate", expectedLen: 0, actualLen: len(src)} - } - - return nil -} - -func (src *Terminate) Encode(dst []byte) []byte { - return append(dst, 'X', 0, 0, 0, 4) -} - -func (src *Terminate) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Type string - }{ - Type: "Terminate", - }) -} diff --git a/pgtype/array.go b/pgtype/array.go index 5b852ed5..9ce0f003 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -8,7 +8,7 @@ import ( "strings" "unicode" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go index 4231e29d..623937dc 100644 --- a/pgtype/bool_array.go +++ b/pgtype/bool_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/box.go b/pgtype/box.go index 4c5a4406..4c825c56 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/bpchar_array.go b/pgtype/bpchar_array.go index b3f36cb6..d1ee2419 100644 --- a/pgtype/bpchar_array.go +++ b/pgtype/bpchar_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go index 9c094b28..68122961 100644 --- a/pgtype/bytea_array.go +++ b/pgtype/bytea_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go index c254c834..338d4904 100644 --- a/pgtype/cidr_array.go +++ b/pgtype/cidr_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "net" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/circle.go b/pgtype/circle.go index 15ea447b..a3bb56f1 100644 --- a/pgtype/circle.go +++ b/pgtype/circle.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/date.go b/pgtype/date.go index b1d4c11d..85c698aa 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "time" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/date_array.go b/pgtype/date_array.go index c0f5c21c..d04666f1 100644 --- a/pgtype/date_array.go +++ b/pgtype/date_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "time" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/daterange.go b/pgtype/daterange.go index 47cd7e46..d10d34c0 100644 --- a/pgtype/daterange.go +++ b/pgtype/daterange.go @@ -3,7 +3,7 @@ package pgtype import ( "database/sql/driver" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/float4.go b/pgtype/float4.go index 2207594a..c4feb0a7 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -6,7 +6,7 @@ import ( "math" "strconv" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go index fba181d3..4e07ba43 100644 --- a/pgtype/float4_array.go +++ b/pgtype/float4_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/float8.go b/pgtype/float8.go index dd34f541..63944d45 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -6,7 +6,7 @@ import ( "math" "strconv" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go index 13dbf27f..e4c340b2 100644 --- a/pgtype/float8_array.go +++ b/pgtype/float8_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/hstore.go b/pgtype/hstore.go index 71b030f9..754c5a3f 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -10,7 +10,7 @@ import ( "github.com/pkg/errors" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) // Hstore represents an hstore column that can be null or have null values diff --git a/pgtype/hstore_array.go b/pgtype/hstore_array.go index 2b8cf37e..239c5d9c 100644 --- a/pgtype/hstore_array.go +++ b/pgtype/hstore_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go index dba369d2..7b4cf457 100644 --- a/pgtype/inet_array.go +++ b/pgtype/inet_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "net" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/int2.go b/pgtype/int2.go index 6156ea77..72110684 100644 --- a/pgtype/int2.go +++ b/pgtype/int2.go @@ -6,7 +6,7 @@ import ( "math" "strconv" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go index 7fefbd95..5b4c2e1a 100644 --- a/pgtype/int2_array.go +++ b/pgtype/int2_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/int4.go b/pgtype/int4.go index 261c5118..9ad878c4 100644 --- a/pgtype/int4.go +++ b/pgtype/int4.go @@ -7,7 +7,7 @@ import ( "math" "strconv" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go index 86656524..77ad8654 100644 --- a/pgtype/int4_array.go +++ b/pgtype/int4_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/int4range.go b/pgtype/int4range.go index 95ad1521..67bbfcd2 100644 --- a/pgtype/int4range.go +++ b/pgtype/int4range.go @@ -3,7 +3,7 @@ package pgtype import ( "database/sql/driver" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/int8.go b/pgtype/int8.go index 00a8cd00..39b8a0a8 100644 --- a/pgtype/int8.go +++ b/pgtype/int8.go @@ -7,7 +7,7 @@ import ( "math" "strconv" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go index 15a8398a..03b169d2 100644 --- a/pgtype/int8_array.go +++ b/pgtype/int8_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/int8range.go b/pgtype/int8range.go index 61d860d3..25839a7b 100644 --- a/pgtype/int8range.go +++ b/pgtype/int8range.go @@ -3,7 +3,7 @@ package pgtype import ( "database/sql/driver" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/interval.go b/pgtype/interval.go index dc696319..75969904 100644 --- a/pgtype/interval.go +++ b/pgtype/interval.go @@ -8,7 +8,7 @@ import ( "strings" "time" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/line.go b/pgtype/line.go index 5fdc5604..6ac4ac2a 100644 --- a/pgtype/line.go +++ b/pgtype/line.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/lseg.go b/pgtype/lseg.go index 4445ea51..c0e77799 100644 --- a/pgtype/lseg.go +++ b/pgtype/lseg.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/macaddr_array.go b/pgtype/macaddr_array.go index bd8b4c5a..c6bc2450 100644 --- a/pgtype/macaddr_array.go +++ b/pgtype/macaddr_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "net" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/numeric.go b/pgtype/numeric.go index fb63df75..fb6e1a00 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/numeric_array.go b/pgtype/numeric_array.go index b5e38539..0d26f3b5 100644 --- a/pgtype/numeric_array.go +++ b/pgtype/numeric_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/numrange.go b/pgtype/numrange.go index aaed62ce..ff9d5372 100644 --- a/pgtype/numrange.go +++ b/pgtype/numrange.go @@ -3,7 +3,7 @@ package pgtype import ( "database/sql/driver" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/oid.go b/pgtype/oid.go index 59370d66..2afc60f8 100644 --- a/pgtype/oid.go +++ b/pgtype/oid.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "strconv" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/path.go b/pgtype/path.go index 69083712..c1b72322 100644 --- a/pgtype/path.go +++ b/pgtype/path.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go index e441a690..37178b5c 100644 --- a/pgtype/pguint32.go +++ b/pgtype/pguint32.go @@ -6,7 +6,7 @@ import ( "math" "strconv" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/point.go b/pgtype/point.go index 98a32d34..fefe5d1f 100644 --- a/pgtype/point.go +++ b/pgtype/point.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/polygon.go b/pgtype/polygon.go index d84a0abd..904e86e1 100644 --- a/pgtype/polygon.go +++ b/pgtype/polygon.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/text_array.go b/pgtype/text_array.go index d53f0b7b..ec487a23 100644 --- a/pgtype/text_array.go +++ b/pgtype/text_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/tid.go b/pgtype/tid.go index 21852a14..e859865b 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -7,7 +7,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 6292521a..f8a4070d 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "time" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go index 11b32a11..493088a2 100644 --- a/pgtype/timestamp_array.go +++ b/pgtype/timestamp_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "time" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 2b9d2a64..ca9b538d 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "time" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go index 31c11f94..612e9904 100644 --- a/pgtype/timestamptz_array.go +++ b/pgtype/timestamptz_array.go @@ -5,7 +5,7 @@ import ( "encoding/binary" "time" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/tsrange.go b/pgtype/tsrange.go index 8a67d65e..d771a761 100644 --- a/pgtype/tsrange.go +++ b/pgtype/tsrange.go @@ -3,7 +3,7 @@ package pgtype import ( "database/sql/driver" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/tstzrange.go b/pgtype/tstzrange.go index b5129093..9a8c782e 100644 --- a/pgtype/tstzrange.go +++ b/pgtype/tstzrange.go @@ -3,7 +3,7 @@ package pgtype import ( "database/sql/driver" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index 6b46a23e..b33e7d99 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -5,7 +5,7 @@ import ( "fmt" "io" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type <%= pgtype_array_type %> struct { diff --git a/pgtype/typed_range.go.erb b/pgtype/typed_range.go.erb index 91a5cb97..035a71af 100644 --- a/pgtype/typed_range.go.erb +++ b/pgtype/typed_range.go.erb @@ -6,7 +6,7 @@ import ( "fmt" "io" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" ) type <%= range_type %> struct { diff --git a/pgtype/uuid_array.go b/pgtype/uuid_array.go index 13efdb23..cddd62f1 100644 --- a/pgtype/uuid_array.go +++ b/pgtype/uuid_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/varbit.go b/pgtype/varbit.go index dfa194d2..2c25b1fb 100644 --- a/pgtype/varbit.go +++ b/pgtype/varbit.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/pgtype/varchar_array.go b/pgtype/varchar_array.go index a7f23fba..0a929920 100644 --- a/pgtype/varchar_array.go +++ b/pgtype/varchar_array.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "encoding/binary" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/pkg/errors" ) diff --git a/query.go b/query.go index c169db8d..5b301eaf 100644 --- a/query.go +++ b/query.go @@ -9,8 +9,8 @@ import ( "github.com/pkg/errors" + "github.com/jackc/pgconn" "github.com/jackc/pgx/internal/sanitize" - "github.com/jackc/pgx/pgconn" "github.com/jackc/pgx/pgtype" ) diff --git a/query_test.go b/query_test.go index 726061ec..2d638784 100644 --- a/query_test.go +++ b/query_test.go @@ -12,8 +12,8 @@ import ( "time" "github.com/cockroachdb/apd" + "github.com/jackc/pgconn" "github.com/jackc/pgx" - "github.com/jackc/pgx/pgconn" "github.com/jackc/pgx/pgtype" satori "github.com/jackc/pgx/pgtype/ext/satori-uuid" uuid "github.com/satori/go.uuid" diff --git a/replication.go b/replication.go index 21d9a3d8..5493311e 100644 --- a/replication.go +++ b/replication.go @@ -9,9 +9,9 @@ import ( "github.com/pkg/errors" - "github.com/jackc/pgx/pgconn" - "github.com/jackc/pgx/pgio" - "github.com/jackc/pgx/pgproto3" + "github.com/jackc/pgconn" + "github.com/jackc/pgio" + "github.com/jackc/pgproto3" "github.com/jackc/pgx/pgtype" ) diff --git a/replication_test.go b/replication_test.go index def75b7c..e68dc413 100644 --- a/replication_test.go +++ b/replication_test.go @@ -10,8 +10,8 @@ import ( "testing" "time" + "github.com/jackc/pgconn" "github.com/jackc/pgx" - "github.com/jackc/pgx/pgconn" ) // This function uses a postgresql 9.6 specific column diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index cf2b91b1..be6d9e6f 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -12,9 +12,9 @@ import ( "testing" "time" + "github.com/jackc/pgproto3" "github.com/jackc/pgx" "github.com/jackc/pgx/pgmock" - "github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/stdlib" ) diff --git a/stress_test.go b/stress_test.go index eb2e9b31..182b725d 100644 --- a/stress_test.go +++ b/stress_test.go @@ -13,7 +13,7 @@ package pgx_test // "github.com/jackc/fake" // "github.com/jackc/pgx" -// "github.com/jackc/pgx/pgconn" +// "github.com/jackc/pgconn" // ) // type execer interface { diff --git a/tx.go b/tx.go index a045d6ab..cc3b2fa7 100644 --- a/tx.go +++ b/tx.go @@ -6,7 +6,7 @@ import ( "fmt" "time" - "github.com/jackc/pgx/pgconn" + "github.com/jackc/pgconn" "github.com/pkg/errors" ) diff --git a/tx_test.go b/tx_test.go index 908a4d34..4b6142fe 100644 --- a/tx_test.go +++ b/tx_test.go @@ -7,10 +7,10 @@ import ( "testing" "time" + "github.com/jackc/pgconn" + "github.com/jackc/pgproto3" "github.com/jackc/pgx" - "github.com/jackc/pgx/pgconn" "github.com/jackc/pgx/pgmock" - "github.com/jackc/pgx/pgproto3" ) func TestTransactionSuccessfulCommit(t *testing.T) { diff --git a/values.go b/values.go index 0c571d74..fc36f678 100644 --- a/values.go +++ b/values.go @@ -7,7 +7,7 @@ import ( "reflect" "time" - "github.com/jackc/pgx/pgio" + "github.com/jackc/pgio" "github.com/jackc/pgx/pgtype" "github.com/pkg/errors" )