Use extracted packages with Go modules
This commit is contained in:
@@ -56,7 +56,7 @@ database/sql compatibility layer for pgx. pgx can be used as a normal database/s
|
||||
|
||||
Approximately 60 PostgreSQL types are supported including uuid, hstore, json, bytea, numeric, interval, inet, and arrays. These types support database/sql interfaces and are usable even outside of pgx. They are fully tested in pgx and pq. They also support a higher performance interface when used with the pgx driver.
|
||||
|
||||
## github.com/jackc/pgx/pgproto3
|
||||
## github.com/jackc/pgproto3
|
||||
|
||||
pgproto3 provides standalone encoding and decoding of the PostgreSQL v3 wire protocol. This is useful for implementing very low level PostgreSQL tooling.
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ package pgx
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
+1
-1
@@ -5,8 +5,8 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
package chunkreader
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
type ChunkReader struct {
|
||||
r io.Reader
|
||||
|
||||
buf []byte
|
||||
rp, wp int // buf read position and write position
|
||||
|
||||
options Options
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
MinBufLen int // Minimum buffer length
|
||||
}
|
||||
|
||||
func NewChunkReader(r io.Reader) *ChunkReader {
|
||||
cr, err := NewChunkReaderEx(r, Options{})
|
||||
if err != nil {
|
||||
panic("default options can't be bad")
|
||||
}
|
||||
|
||||
return cr
|
||||
}
|
||||
|
||||
func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) {
|
||||
if options.MinBufLen == 0 {
|
||||
options.MinBufLen = 4096
|
||||
}
|
||||
|
||||
return &ChunkReader{
|
||||
r: r,
|
||||
buf: make([]byte, options.MinBufLen),
|
||||
options: options,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Next returns buf filled with the next n bytes. If an error occurs, buf will
|
||||
// be nil.
|
||||
func (r *ChunkReader) Next(n int) (buf []byte, err error) {
|
||||
// n bytes already in buf
|
||||
if (r.wp - r.rp) >= n {
|
||||
buf = r.buf[r.rp : r.rp+n]
|
||||
r.rp += n
|
||||
return buf, err
|
||||
}
|
||||
|
||||
// available space in buf is less than n
|
||||
if len(r.buf) < n {
|
||||
r.copyBufContents(r.newBuf(n))
|
||||
}
|
||||
|
||||
// buf is large enough, but need to shift filled area to start to make enough contiguous space
|
||||
minReadCount := n - (r.wp - r.rp)
|
||||
if (len(r.buf) - r.wp) < minReadCount {
|
||||
newBuf := r.newBuf(n)
|
||||
r.copyBufContents(newBuf)
|
||||
}
|
||||
|
||||
if err := r.appendAtLeast(minReadCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buf = r.buf[r.rp : r.rp+n]
|
||||
r.rp += n
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func (r *ChunkReader) appendAtLeast(fillLen int) error {
|
||||
n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen)
|
||||
r.wp += n
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ChunkReader) newBuf(size int) []byte {
|
||||
if size < r.options.MinBufLen {
|
||||
size = r.options.MinBufLen
|
||||
}
|
||||
return make([]byte, size)
|
||||
}
|
||||
|
||||
func (r *ChunkReader) copyBufContents(dest []byte) {
|
||||
r.wp = copy(dest, r.buf[r.rp:r.wp])
|
||||
r.rp = 0
|
||||
r.buf = dest
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
package chunkreader
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) {
|
||||
server := &bytes.Buffer{}
|
||||
r, err := NewChunkReaderEx(server, Options{MinBufLen: 4})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
src := []byte{1, 2, 3, 4}
|
||||
server.Write(src)
|
||||
|
||||
n1, err := r.Next(2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Compare(n1, src[0:2]) != 0 {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:2], n1)
|
||||
}
|
||||
|
||||
n2, err := r.Next(2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Compare(n2, src[2:4]) != 0 {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[2:4], n2)
|
||||
}
|
||||
|
||||
if bytes.Compare(r.buf, src) != 0 {
|
||||
t.Fatalf("Expected r.buf to be %v, but it was %v", src, r.buf)
|
||||
}
|
||||
if r.rp != 4 {
|
||||
t.Fatalf("Expected r.rp to be %v, but it was %v", 4, r.rp)
|
||||
}
|
||||
if r.wp != 4 {
|
||||
t.Fatalf("Expected r.wp to be %v, but it was %v", 4, r.wp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) {
|
||||
server := &bytes.Buffer{}
|
||||
r, err := NewChunkReaderEx(server, Options{MinBufLen: 4})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
src := []byte{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
server.Write(src)
|
||||
|
||||
n1, err := r.Next(5)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Compare(n1, src[0:5]) != 0 {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:5], n1)
|
||||
}
|
||||
if len(r.buf) != 5 {
|
||||
t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 5, len(r.buf))
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkReaderDoesNotReuseBuf(t *testing.T) {
|
||||
server := &bytes.Buffer{}
|
||||
r, err := NewChunkReaderEx(server, Options{MinBufLen: 4})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
src := []byte{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
server.Write(src)
|
||||
|
||||
n1, err := r.Next(4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Compare(n1, src[0:4]) != 0 {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:4], n1)
|
||||
}
|
||||
|
||||
n2, err := r.Next(4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Compare(n2, src[4:8]) != 0 {
|
||||
t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2)
|
||||
}
|
||||
|
||||
if bytes.Compare(n1, src[0:4]) != 0 {
|
||||
t.Fatalf("Expected KeepLast to prevent Next from overwriting buf, expected %v but it was %v", src[0:4], n1)
|
||||
}
|
||||
}
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgproto3"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
|
||||
+1
-1
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
|
||||
+1
-1
@@ -11,8 +11,8 @@ import (
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
)
|
||||
|
||||
func createConnPool(t *testing.T, maxConnections int) *pgx.ConnPool {
|
||||
|
||||
+1
-1
@@ -10,8 +10,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
+1
-1
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
+1
-1
@@ -6,8 +6,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
+2
-2
@@ -3,8 +3,8 @@ package pgx
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/jackc/pgproto3"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
module github.com/jackc/pgx
|
||||
|
||||
go 1.12
|
||||
|
||||
require (
|
||||
github.com/cockroachdb/apd v1.1.0
|
||||
github.com/jackc/pgconn v0.0.0-20190330221323-ed7d91dc9873
|
||||
github.com/jackc/pgio v1.0.0
|
||||
github.com/jackc/pgproto3 v1.0.0
|
||||
github.com/pkg/errors v0.8.1
|
||||
github.com/satori/go.uuid v1.2.0
|
||||
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24
|
||||
github.com/stretchr/testify v1.3.0
|
||||
)
|
||||
@@ -0,0 +1,25 @@
|
||||
github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
|
||||
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
|
||||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0=
|
||||
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
|
||||
github.com/jackc/pgconn v0.0.0-20190330221323-ed7d91dc9873 h1:M68R77AKFS7dub7R7WgJ9D6yiNuExOYhBuGtazlbr10=
|
||||
github.com/jackc/pgconn v0.0.0-20190330221323-ed7d91dc9873/go.mod h1:8Bzf8vzi/ZpcgLgrq8IUHjZX4ZU+Hf6N6/AJ85+fDeE=
|
||||
github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
|
||||
github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgproto3 v1.0.0 h1:25tUmlES7eyD96oYaUHc1dLOFbgcJtFzCdnOOoqmA1I=
|
||||
github.com/jackc/pgproto3 v1.0.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
|
||||
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww=
|
||||
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
|
||||
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 h1:pntxY8Ary0t43dCZ5dqY4YTJCObLY1kIXl0uzMv+7DE=
|
||||
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
+1
-1
@@ -4,8 +4,8 @@ import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
)
|
||||
|
||||
func TestLargeObjects(t *testing.T) {
|
||||
|
||||
+1
-1
@@ -6,7 +6,7 @@ import (
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
package pgconn_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func BenchmarkConnect(b *testing.B) {
|
||||
benchmarks := []struct {
|
||||
name string
|
||||
env string
|
||||
}{
|
||||
{"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
|
||||
{"TCP", "PGX_TEST_TCP_CONN_STRING"},
|
||||
}
|
||||
|
||||
for _, bm := range benchmarks {
|
||||
b.Run(bm.name, func(b *testing.B) {
|
||||
connString := os.Getenv(bm.env)
|
||||
if connString == "" {
|
||||
b.Skipf("Skipping due to missing environment variable %v", bm.env)
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn, err := pgconn.Connect(context.Background(), connString)
|
||||
require.Nil(b, err)
|
||||
|
||||
err = conn.Close(context.Background())
|
||||
require.Nil(b, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExec(b *testing.B) {
|
||||
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(b, err)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := conn.Exec(context.Background(), "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date").ReadAll()
|
||||
require.Nil(b, err)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExecPossibleToCancel(b *testing.B) {
|
||||
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(b, err)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date").ReadAll()
|
||||
require.Nil(b, err)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExecPrepared(b *testing.B) {
|
||||
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(b, err)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
_, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
|
||||
require.Nil(b, err)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
result := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil).Read()
|
||||
require.Nil(b, result.Err)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExecPreparedPossibleToCancel(b *testing.B) {
|
||||
conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(b, err)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
_, err = conn.Prepare(ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
|
||||
require.Nil(b, err)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
result := conn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
|
||||
require.Nil(b, result.Err)
|
||||
}
|
||||
}
|
||||
@@ -1,501 +0,0 @@
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/pgpassfile"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||
|
||||
// Config is the settings used to establish a connection to a PostgreSQL server.
|
||||
type Config struct {
|
||||
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||
Port uint16
|
||||
Database string
|
||||
User string
|
||||
Password string
|
||||
TLSConfig *tls.Config // nil disables TLS
|
||||
DialFunc DialFunc // e.g. net.Dialer.DialContext
|
||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||
|
||||
Fallbacks []*FallbackConfig
|
||||
|
||||
// AfterConnectFunc is called after successful connection. It can be used to set up the connection or to validate that
|
||||
// server is acceptable. If this returns an error the connection is closed and the next fallback config is tried. This
|
||||
// allows implementing high availability behavior such as libpq does with target_session_attrs.
|
||||
AfterConnectFunc AfterConnectFunc
|
||||
|
||||
// OnContextCancel is a callback function used to override cancellation behavior. It is called when a context.Context
|
||||
// is canceled. Default cancellation behavior is to establish another connection to the PostgreSQL server and send a
|
||||
// query cancel request. Some non-PostgreSQL servers (e.g. CockroachDB) that speak a subset of the PostgreSQL wire
|
||||
// protocol do not support this cancellation method.
|
||||
//
|
||||
// It is called from a background goroutine. When the cancellation process has finished ContextCancel.Finish must be
|
||||
// called whether it was successful or not. If an error occurs the connection should be closed. The connection must be
|
||||
// in a ready for query state or be closed when ContextCancel.Finish is called. Use PgConn.ReceiveMessage() to read
|
||||
// the connection until a ready for query message is received.
|
||||
OnContextCancel func(*ContextCancel)
|
||||
|
||||
// OnNotice is a callback function called when a notice response is received.
|
||||
OnNotice NoticeHandler
|
||||
|
||||
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
|
||||
OnNotification NotificationHandler
|
||||
}
|
||||
|
||||
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
|
||||
// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections.
|
||||
type FallbackConfig struct {
|
||||
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||
Port uint16
|
||||
TLSConfig *tls.Config // nil disables TLS
|
||||
}
|
||||
|
||||
// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
|
||||
// net.Dial.
|
||||
func NetworkAddress(host string, port uint16) (network, address string) {
|
||||
if strings.HasPrefix(host, "/") {
|
||||
network = "unix"
|
||||
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
|
||||
} else {
|
||||
network = "tcp"
|
||||
address = fmt.Sprintf("%s:%d", host, port)
|
||||
}
|
||||
return network, address
|
||||
}
|
||||
|
||||
// ParseConfig builds a []*Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same
|
||||
// defaults as libpq (e.g. port=5432) and understands most PG* environment variables. connString may be a URL or a DSN.
|
||||
// It also may be empty to only read from the environment. If a password is not supplied it will attempt to read the
|
||||
// .pgpass file.
|
||||
//
|
||||
// # Example DSN
|
||||
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
|
||||
//
|
||||
// # Example URL
|
||||
// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca
|
||||
//
|
||||
// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated
|
||||
// values that will be tried in order. This can be used as part of a high availability system. See
|
||||
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
|
||||
//
|
||||
// # Example URL
|
||||
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
|
||||
//
|
||||
// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed
|
||||
// via database URL or DSN:
|
||||
//
|
||||
// PGHOST
|
||||
// PGPORT
|
||||
// PGDATABASE
|
||||
// PGUSER
|
||||
// PGPASSWORD
|
||||
// PGPASSFILE
|
||||
// PGSSLMODE
|
||||
// PGSSLCERT
|
||||
// PGSSLKEY
|
||||
// PGSSLROOTCERT
|
||||
// PGAPPNAME
|
||||
// PGCONNECT_TIMEOUT
|
||||
// PGTARGETSESSIONATTRS
|
||||
//
|
||||
// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables.
|
||||
//
|
||||
// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are
|
||||
// usually but not always the environment variable name downcased and without the "PG" prefix.
|
||||
//
|
||||
// Important TLS Security Notes:
|
||||
//
|
||||
// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if
|
||||
// not set.
|
||||
//
|
||||
// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of
|
||||
// security each sslmode provides.
|
||||
//
|
||||
// "verify-ca" mode currently is treated as "verify-full". e.g. It has stronger
|
||||
// security guarantees than it would with libpq. Do not rely on this behavior as it
|
||||
// may be possible to match libpq in the future. If you need full security use
|
||||
// "verify-full".
|
||||
func ParseConfig(connString string) (*Config, error) {
|
||||
settings := defaultSettings()
|
||||
addEnvSettings(settings)
|
||||
|
||||
if connString != "" {
|
||||
// connString may be a database URL or a DSN
|
||||
if strings.HasPrefix(connString, "postgres://") {
|
||||
err := addURLSettings(settings, connString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
err := addDSNSettings(settings, connString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
Database: settings["database"],
|
||||
User: settings["user"],
|
||||
Password: settings["password"],
|
||||
RuntimeParams: make(map[string]string),
|
||||
}
|
||||
|
||||
if connectTimeout, present := settings["connect_timeout"]; present {
|
||||
dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.DialFunc = dialFunc
|
||||
} else {
|
||||
defaultDialer := makeDefaultDialer()
|
||||
config.DialFunc = defaultDialer.DialContext
|
||||
}
|
||||
|
||||
notRuntimeParams := map[string]struct{}{
|
||||
"host": struct{}{},
|
||||
"port": struct{}{},
|
||||
"database": struct{}{},
|
||||
"user": struct{}{},
|
||||
"password": struct{}{},
|
||||
"passfile": struct{}{},
|
||||
"connect_timeout": struct{}{},
|
||||
"sslmode": struct{}{},
|
||||
"sslkey": struct{}{},
|
||||
"sslcert": struct{}{},
|
||||
"sslrootcert": struct{}{},
|
||||
"target_session_attrs": struct{}{},
|
||||
}
|
||||
|
||||
for k, v := range settings {
|
||||
if _, present := notRuntimeParams[k]; present {
|
||||
continue
|
||||
}
|
||||
config.RuntimeParams[k] = v
|
||||
}
|
||||
|
||||
fallbacks := []*FallbackConfig{}
|
||||
|
||||
hosts := strings.Split(settings["host"], ",")
|
||||
ports := strings.Split(settings["port"], ",")
|
||||
|
||||
for i, host := range hosts {
|
||||
var portStr string
|
||||
if i < len(ports) {
|
||||
portStr = ports[i]
|
||||
} else {
|
||||
portStr = ports[0]
|
||||
}
|
||||
|
||||
port, err := parsePort(portStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid port: %v", settings["port"])
|
||||
}
|
||||
|
||||
var tlsConfigs []*tls.Config
|
||||
|
||||
// Ignore TLS settings if Unix domain socket like libpq
|
||||
if network, _ := NetworkAddress(host, port); network == "unix" {
|
||||
tlsConfigs = append(tlsConfigs, nil)
|
||||
} else {
|
||||
var err error
|
||||
tlsConfigs, err = configTLS(settings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
for _, tlsConfig := range tlsConfigs {
|
||||
fallbacks = append(fallbacks, &FallbackConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
TLSConfig: tlsConfig,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
config.Host = fallbacks[0].Host
|
||||
config.Port = fallbacks[0].Port
|
||||
config.TLSConfig = fallbacks[0].TLSConfig
|
||||
config.Fallbacks = fallbacks[1:]
|
||||
|
||||
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
|
||||
if err == nil {
|
||||
if config.Password == "" {
|
||||
host := config.Host
|
||||
if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" {
|
||||
host = "localhost"
|
||||
}
|
||||
|
||||
config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User)
|
||||
}
|
||||
}
|
||||
|
||||
if settings["target_session_attrs"] == "read-write" {
|
||||
config.AfterConnectFunc = AfterConnectTargetSessionAttrsReadWrite
|
||||
} else if settings["target_session_attrs"] != "any" {
|
||||
return nil, fmt.Errorf("unknown target_session_attrs value %v", settings["target_session_attrs"])
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func defaultSettings() map[string]string {
|
||||
settings := make(map[string]string)
|
||||
|
||||
settings["host"] = defaultHost()
|
||||
settings["port"] = "5432"
|
||||
|
||||
// Default to the OS user name. Purposely ignoring err getting user name from
|
||||
// OS. The client application will simply have to specify the user in that
|
||||
// case (which they typically will be doing anyway).
|
||||
user, err := user.Current()
|
||||
if err == nil {
|
||||
settings["user"] = user.Username
|
||||
settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass")
|
||||
}
|
||||
|
||||
settings["target_session_attrs"] = "any"
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
|
||||
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
|
||||
// checks the existence of common locations.
|
||||
func defaultHost() string {
|
||||
candidatePaths := []string{
|
||||
"/var/run/postgresql", // Debian
|
||||
"/private/tmp", // OSX - homebrew
|
||||
"/tmp", // standard PostgreSQL
|
||||
}
|
||||
|
||||
for _, path := range candidatePaths {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return path
|
||||
}
|
||||
}
|
||||
|
||||
return "localhost"
|
||||
}
|
||||
|
||||
func addEnvSettings(settings map[string]string) {
|
||||
nameMap := map[string]string{
|
||||
"PGHOST": "host",
|
||||
"PGPORT": "port",
|
||||
"PGDATABASE": "database",
|
||||
"PGUSER": "user",
|
||||
"PGPASSWORD": "password",
|
||||
"PGPASSFILE": "passfile",
|
||||
"PGAPPNAME": "application_name",
|
||||
"PGCONNECT_TIMEOUT": "connect_timeout",
|
||||
"PGSSLMODE": "sslmode",
|
||||
"PGSSLKEY": "sslkey",
|
||||
"PGSSLCERT": "sslcert",
|
||||
"PGSSLROOTCERT": "sslrootcert",
|
||||
"PGTARGETSESSIONATTRS": "target_session_attrs",
|
||||
}
|
||||
|
||||
for envname, realname := range nameMap {
|
||||
value := os.Getenv(envname)
|
||||
if value != "" {
|
||||
settings[realname] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func addURLSettings(settings map[string]string, connString string) error {
|
||||
url, err := url.Parse(connString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if url.User != nil {
|
||||
settings["user"] = url.User.Username()
|
||||
if password, present := url.User.Password(); present {
|
||||
settings["password"] = password
|
||||
}
|
||||
}
|
||||
|
||||
// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
|
||||
var hosts []string
|
||||
var ports []string
|
||||
for _, host := range strings.Split(url.Host, ",") {
|
||||
parts := strings.SplitN(host, ":", 2)
|
||||
if parts[0] != "" {
|
||||
hosts = append(hosts, parts[0])
|
||||
}
|
||||
if len(parts) == 2 {
|
||||
ports = append(ports, parts[1])
|
||||
}
|
||||
}
|
||||
if len(hosts) > 0 {
|
||||
settings["host"] = strings.Join(hosts, ",")
|
||||
}
|
||||
if len(ports) > 0 {
|
||||
settings["port"] = strings.Join(ports, ",")
|
||||
}
|
||||
|
||||
database := strings.TrimLeft(url.Path, "/")
|
||||
if database != "" {
|
||||
settings["database"] = database
|
||||
}
|
||||
|
||||
for k, v := range url.Query() {
|
||||
settings[k] = v[0]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`)
|
||||
|
||||
func addDSNSettings(settings map[string]string, s string) error {
|
||||
m := dsnRegexp.FindAllStringSubmatch(s, -1)
|
||||
|
||||
for _, b := range m {
|
||||
settings[b[1]] = b[2]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type pgTLSArgs struct {
|
||||
sslMode string
|
||||
sslRootCert string
|
||||
sslCert string
|
||||
sslKey string
|
||||
}
|
||||
|
||||
// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is
|
||||
// necessary to allow returning multiple TLS configs as sslmode "allow" and
|
||||
// "prefer" allow fallback.
|
||||
func configTLS(settings map[string]string) ([]*tls.Config, error) {
|
||||
host := settings["host"]
|
||||
sslmode := settings["sslmode"]
|
||||
sslrootcert := settings["sslrootcert"]
|
||||
sslcert := settings["sslcert"]
|
||||
sslkey := settings["sslkey"]
|
||||
|
||||
// Match libpq default behavior
|
||||
if sslmode == "" {
|
||||
sslmode = "prefer"
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{}
|
||||
|
||||
switch sslmode {
|
||||
case "disable":
|
||||
return []*tls.Config{nil}, nil
|
||||
case "allow", "prefer":
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
case "require":
|
||||
tlsConfig.InsecureSkipVerify = sslrootcert == ""
|
||||
case "verify-ca", "verify-full":
|
||||
tlsConfig.ServerName = host
|
||||
default:
|
||||
return nil, errors.New("sslmode is invalid")
|
||||
}
|
||||
|
||||
if sslrootcert != "" {
|
||||
caCertPool := x509.NewCertPool()
|
||||
|
||||
caPath := sslrootcert
|
||||
caCert, err := ioutil.ReadFile(caPath)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "unable to read CA file %q", caPath)
|
||||
}
|
||||
|
||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
return nil, errors.Wrap(err, "unable to add CA to cert pool")
|
||||
}
|
||||
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
tlsConfig.ClientCAs = caCertPool
|
||||
}
|
||||
|
||||
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
|
||||
return nil, fmt.Errorf(`both "sslcert" and "sslkey" are required`)
|
||||
}
|
||||
|
||||
if sslcert != "" && sslkey != "" {
|
||||
cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unable to read cert")
|
||||
}
|
||||
|
||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
|
||||
switch sslmode {
|
||||
case "allow":
|
||||
return []*tls.Config{nil, tlsConfig}, nil
|
||||
case "prefer":
|
||||
return []*tls.Config{tlsConfig, nil}, nil
|
||||
case "require", "verify-ca", "verify-full":
|
||||
return []*tls.Config{tlsConfig}, nil
|
||||
default:
|
||||
panic("BUG: bad sslmode should already have been caught")
|
||||
}
|
||||
}
|
||||
|
||||
func parsePort(s string) (uint16, error) {
|
||||
port, err := strconv.ParseUint(s, 10, 16)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if port < 1 || port > math.MaxUint16 {
|
||||
return 0, errors.New("outside range")
|
||||
}
|
||||
return uint16(port), nil
|
||||
}
|
||||
|
||||
func makeDefaultDialer() *net.Dialer {
|
||||
return &net.Dialer{KeepAlive: 5 * time.Minute}
|
||||
}
|
||||
|
||||
func makeConnectTimeoutDialFunc(s string) (DialFunc, error) {
|
||||
timeout, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if timeout < 0 {
|
||||
return nil, errors.New("negative timeout")
|
||||
}
|
||||
|
||||
d := makeDefaultDialer()
|
||||
d.Timeout = time.Duration(timeout) * time.Second
|
||||
return d.DialContext, nil
|
||||
}
|
||||
|
||||
// AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=read-write.
|
||||
func AfterConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
|
||||
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
|
||||
if result.Err != nil {
|
||||
return result.Err
|
||||
}
|
||||
|
||||
if string(result.Rows[0][0]) == "on" {
|
||||
return errors.New("read only connection")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,562 +0,0 @@
|
||||
package pgconn_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/user"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var osUserName string
|
||||
osUser, err := user.Current()
|
||||
if err == nil {
|
||||
osUserName = osUser.Username
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
connString string
|
||||
config *pgconn.Config
|
||||
}{
|
||||
// Test all sslmodes
|
||||
{
|
||||
name: "sslmode not set (prefer)",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sslmode disable",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sslmode allow",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=allow",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sslmode prefer",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer",
|
||||
config: &pgconn.Config{
|
||||
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sslmode require",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=require",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sslmode verify-ca",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{ServerName: "localhost"},
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sslmode verify-full",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-full",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{ServerName: "localhost"},
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "database url everything",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{
|
||||
"application_name": "pgxtest",
|
||||
"search_path": "myschema",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "database url missing password",
|
||||
connString: "postgres://jack@localhost:5432/mydb?sslmode=disable",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "database url missing user and password",
|
||||
connString: "postgres://localhost:5432/mydb?sslmode=disable",
|
||||
config: &pgconn.Config{
|
||||
User: osUserName,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "database url missing port",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "database url unix domain socket host",
|
||||
connString: "postgres:///foo?host=/tmp",
|
||||
config: &pgconn.Config{
|
||||
User: osUserName,
|
||||
Host: "/tmp",
|
||||
Port: 5432,
|
||||
Database: "foo",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DSN everything",
|
||||
connString: "user=jack password=secret host=localhost port=5432 database=mydb sslmode=disable application_name=pgxtest search_path=myschema",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{
|
||||
"application_name": "pgxtest",
|
||||
"search_path": "myschema",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "URL multiple hosts",
|
||||
connString: "postgres://jack:secret@foo,bar,baz/mydb?sslmode=disable",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "foo",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "bar",
|
||||
Port: 5432,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "baz",
|
||||
Port: 5432,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "URL multiple hosts and ports",
|
||||
connString: "postgres://jack:secret@foo:1,bar:2,baz:3/mydb?sslmode=disable",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "foo",
|
||||
Port: 1,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "bar",
|
||||
Port: 2,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "baz",
|
||||
Port: 3,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DSN multiple hosts one port",
|
||||
connString: "user=jack password=secret host=foo,bar,baz port=5432 database=mydb sslmode=disable",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "foo",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "bar",
|
||||
Port: 5432,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "baz",
|
||||
Port: 5432,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DSN multiple hosts multiple ports",
|
||||
connString: "user=jack password=secret host=foo,bar,baz port=1,2,3 database=mydb sslmode=disable",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "foo",
|
||||
Port: 1,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "bar",
|
||||
Port: 2,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "baz",
|
||||
Port: 3,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple hosts and fallback tsl",
|
||||
connString: "user=jack password=secret host=foo,bar,baz database=mydb sslmode=prefer",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "foo",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "foo",
|
||||
Port: 5432,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "bar",
|
||||
Port: 5432,
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}},
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "bar",
|
||||
Port: 5432,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "baz",
|
||||
Port: 5432,
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}},
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "baz",
|
||||
Port: 5432,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "target_session_attrs",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
AfterConnectFunc: pgconn.AfterConnectTargetSessionAttrsReadWrite,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
config, err := pgconn.ParseConfig(tt.connString)
|
||||
if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) {
|
||||
continue
|
||||
}
|
||||
|
||||
assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name))
|
||||
}
|
||||
}
|
||||
|
||||
func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) {
|
||||
if !assert.NotNil(t, expected) {
|
||||
return
|
||||
}
|
||||
if !assert.NotNil(t, actual) {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName)
|
||||
assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName)
|
||||
assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
|
||||
assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
|
||||
assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
|
||||
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
|
||||
|
||||
// Can't test function equality, so just test that they are set or not.
|
||||
assert.Equalf(t, expected.AfterConnectFunc == nil, actual.AfterConnectFunc == nil, "%s - AfterConnectFunc", testName)
|
||||
|
||||
if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
|
||||
if expected.TLSConfig != nil {
|
||||
assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName)
|
||||
assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName)
|
||||
}
|
||||
}
|
||||
|
||||
if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) {
|
||||
for i := range expected.Fallbacks {
|
||||
assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i)
|
||||
assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i)
|
||||
|
||||
if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) {
|
||||
if expected.Fallbacks[i].TLSConfig != nil {
|
||||
assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName)
|
||||
assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigEnvLibpq(t *testing.T) {
|
||||
var osUserName string
|
||||
osUser, err := user.Current()
|
||||
if err == nil {
|
||||
osUserName = osUser.Username
|
||||
}
|
||||
|
||||
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"}
|
||||
|
||||
savedEnv := make(map[string]string)
|
||||
for _, n := range pgEnvvars {
|
||||
savedEnv[n] = os.Getenv(n)
|
||||
}
|
||||
defer func() {
|
||||
for k, v := range savedEnv {
|
||||
err := os.Setenv(k, v)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to restore environment: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
envvars map[string]string
|
||||
config *pgconn.Config
|
||||
}{
|
||||
{
|
||||
// not testing no environment at all as that would use default host and that can vary.
|
||||
name: "PGHOST only",
|
||||
envvars: map[string]string{"PGHOST": "123.123.123.123"},
|
||||
config: &pgconn.Config{
|
||||
User: osUserName,
|
||||
Host: "123.123.123.123",
|
||||
Port: 5432,
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "123.123.123.123",
|
||||
Port: 5432,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "All non-TLS environment",
|
||||
envvars: map[string]string{
|
||||
"PGHOST": "123.123.123.123",
|
||||
"PGPORT": "7777",
|
||||
"PGDATABASE": "foo",
|
||||
"PGUSER": "bar",
|
||||
"PGPASSWORD": "baz",
|
||||
"PGCONNECT_TIMEOUT": "10",
|
||||
"PGSSLMODE": "disable",
|
||||
"PGAPPNAME": "pgxtest",
|
||||
},
|
||||
config: &pgconn.Config{
|
||||
Host: "123.123.123.123",
|
||||
Port: 7777,
|
||||
Database: "foo",
|
||||
User: "bar",
|
||||
Password: "baz",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{"application_name": "pgxtest"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
for _, n := range pgEnvvars {
|
||||
err := os.Unsetenv(n)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
for k, v := range tt.envvars {
|
||||
err := os.Setenv(k, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
config, err := pgconn.ParseConfig("")
|
||||
if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) {
|
||||
continue
|
||||
}
|
||||
|
||||
assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigReadsPgPassfile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tf, err := ioutil.TempFile("", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
defer tf.Close()
|
||||
defer os.Remove(tf.Name())
|
||||
|
||||
_, err = tf.Write([]byte("test1:5432:curlydb:curly:nyuknyuknyuk"))
|
||||
require.NoError(t, err)
|
||||
|
||||
connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tf.Name())
|
||||
expected := &pgconn.Config{
|
||||
User: "curly",
|
||||
Password: "nyuknyuknyuk",
|
||||
Host: "test1",
|
||||
Port: 5432,
|
||||
Database: "curlydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
}
|
||||
|
||||
actual, err := pgconn.ParseConfig(connString)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assertConfigsEqual(t, expected, actual, "passfile")
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
// Package pgconn is a low-level PostgreSQL database driver.
|
||||
/*
|
||||
pgconn provides lower level access to a PostgreSQL connection than a database/sql or pgx connection. It operates at
|
||||
nearly the same level is the C library libpq.
|
||||
|
||||
Establishing a Connection
|
||||
|
||||
Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for
|
||||
libpq style environment variables.
|
||||
|
||||
Executing a Query
|
||||
|
||||
ExecParams and ExecPrepared execute a single query. They return readers that iterate over each row. The Read method
|
||||
reads all rows into memory.
|
||||
|
||||
Executing Multiple Queries in a Single Round Trip
|
||||
|
||||
Exec and ExecBatch can execute multiple queries in a single round trip. The return readers that iterate over each query
|
||||
result. The ReadAll method reads all query results into memory.
|
||||
|
||||
Context Support
|
||||
|
||||
All potentially blocking operations take a context.Context. If a context is canceled while a query is in progress the
|
||||
method immediately returns. In the background a cancel request will be sent to the PostgreSQL server. If the
|
||||
cancellation fails or hangs for more than a short time (approximately 15 seconds) the connection will be closed. It is
|
||||
safe to use the connection while this background cancellation is in progress. Any calls will block until the
|
||||
cancellation and resynchronization is complete (and those calls can be aborted by a context cancellation).
|
||||
*/
|
||||
package pgconn
|
||||
@@ -1,31 +0,0 @@
|
||||
package pgconn_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func closeConn(t testing.TB, conn *pgconn.PgConn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
require.Nil(t, conn.Close(ctx))
|
||||
}
|
||||
|
||||
// Do a simple query to ensure the connection is still usable
|
||||
func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).Read()
|
||||
cancel()
|
||||
|
||||
require.Nil(t, result.Err)
|
||||
assert.Equal(t, 3, len(result.Rows))
|
||||
assert.Equal(t, "1", string(result.Rows[0][0]))
|
||||
assert.Equal(t, "2", string(result.Rows[1][0]))
|
||||
assert.Equal(t, "3", string(result.Rows[2][0]))
|
||||
}
|
||||
-1407
File diff suppressed because it is too large
Load Diff
@@ -1,120 +0,0 @@
|
||||
package pgconn_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConnStress(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
actionCount := 100
|
||||
if s := os.Getenv("PGX_TEST_STRESS_FACTOR"); s != "" {
|
||||
stressFactor, err := strconv.ParseInt(s, 10, 64)
|
||||
require.Nil(t, err, "Failed to parse PGX_TEST_STRESS_FACTOR")
|
||||
actionCount *= int(stressFactor)
|
||||
}
|
||||
|
||||
setupStressDB(t, pgConn)
|
||||
|
||||
actions := []struct {
|
||||
name string
|
||||
fn func(*pgconn.PgConn) error
|
||||
}{
|
||||
{"Exec Select", stressExecSelect},
|
||||
{"ExecParams Select", stressExecParamsSelect},
|
||||
{"Batch", stressBatch},
|
||||
{"ExecCanceled", stressExecSelectCanceled},
|
||||
{"ExecParamsCanceled", stressExecParamsSelectCanceled},
|
||||
{"BatchCanceled", stressBatchCanceled},
|
||||
}
|
||||
|
||||
for i := 0; i < actionCount; i++ {
|
||||
action := actions[rand.Intn(len(actions))]
|
||||
err := action.fn(pgConn)
|
||||
require.Nilf(t, err, "%d: %s", i, action.name)
|
||||
}
|
||||
}
|
||||
|
||||
func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) {
|
||||
_, err := pgConn.Exec(context.Background(), `
|
||||
create temporary table widgets(
|
||||
id serial primary key,
|
||||
name varchar not null,
|
||||
description text,
|
||||
creation_time timestamptz default now()
|
||||
);
|
||||
|
||||
insert into widgets(name, description) values
|
||||
('Foo', 'bar'),
|
||||
('baz', 'Something really long Something really long Something really long Something really long Something really long'),
|
||||
('a', 'b')`).ReadAll()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func stressExecSelect(pgConn *pgconn.PgConn) error {
|
||||
_, err := pgConn.Exec(context.Background(), "select * from widgets").ReadAll()
|
||||
return err
|
||||
}
|
||||
|
||||
func stressExecParamsSelect(pgConn *pgconn.PgConn) error {
|
||||
result := pgConn.ExecParams(context.Background(), "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read()
|
||||
return result.Err
|
||||
}
|
||||
|
||||
func stressBatch(pgConn *pgconn.PgConn) error {
|
||||
batch := &pgconn.Batch{}
|
||||
|
||||
batch.ExecParams("select * from widgets", nil, nil, nil, nil)
|
||||
batch.ExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil)
|
||||
_, err := pgConn.ExecBatch(context.Background(), batch).ReadAll()
|
||||
return err
|
||||
}
|
||||
|
||||
func stressExecSelectCanceled(pgConn *pgconn.PgConn) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||
_, err := pgConn.Exec(ctx, "select *, pg_sleep(1) from widgets").ReadAll()
|
||||
cancel()
|
||||
if err != context.DeadlineExceeded {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func stressExecParamsSelectCanceled(pgConn *pgconn.PgConn) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||
result := pgConn.ExecParams(ctx, "select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read()
|
||||
cancel()
|
||||
if result.Err != context.DeadlineExceeded {
|
||||
return result.Err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func stressBatchCanceled(pgConn *pgconn.PgConn) error {
|
||||
batch := &pgconn.Batch{}
|
||||
batch.ExecParams("select * from widgets", nil, nil, nil, nil)
|
||||
batch.ExecParams("select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||
_, err := pgConn.ExecBatch(ctx, batch).ReadAll()
|
||||
cancel()
|
||||
if err != context.DeadlineExceeded {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +0,0 @@
|
||||
// Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol.
|
||||
/*
|
||||
pgio provides functions for appending integers to a []byte while doing byte
|
||||
order conversion.
|
||||
*/
|
||||
package pgio
|
||||
@@ -1,40 +0,0 @@
|
||||
package pgio
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
func AppendUint16(buf []byte, n uint16) []byte {
|
||||
wp := len(buf)
|
||||
buf = append(buf, 0, 0)
|
||||
binary.BigEndian.PutUint16(buf[wp:], n)
|
||||
return buf
|
||||
}
|
||||
|
||||
func AppendUint32(buf []byte, n uint32) []byte {
|
||||
wp := len(buf)
|
||||
buf = append(buf, 0, 0, 0, 0)
|
||||
binary.BigEndian.PutUint32(buf[wp:], n)
|
||||
return buf
|
||||
}
|
||||
|
||||
func AppendUint64(buf []byte, n uint64) []byte {
|
||||
wp := len(buf)
|
||||
buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0)
|
||||
binary.BigEndian.PutUint64(buf[wp:], n)
|
||||
return buf
|
||||
}
|
||||
|
||||
func AppendInt16(buf []byte, n int16) []byte {
|
||||
return AppendUint16(buf, uint16(n))
|
||||
}
|
||||
|
||||
func AppendInt32(buf []byte, n int32) []byte {
|
||||
return AppendUint32(buf, uint32(n))
|
||||
}
|
||||
|
||||
func AppendInt64(buf []byte, n int64) []byte {
|
||||
return AppendUint64(buf, uint64(n))
|
||||
}
|
||||
|
||||
func SetInt32(buf []byte, n int32) {
|
||||
binary.BigEndian.PutUint32(buf, uint32(n))
|
||||
}
|
||||
@@ -1,78 +0,0 @@
|
||||
package pgio
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAppendUint16NilBuf(t *testing.T) {
|
||||
buf := AppendUint16(nil, 1)
|
||||
if !reflect.DeepEqual(buf, []byte{0, 1}) {
|
||||
t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint16EmptyBuf(t *testing.T) {
|
||||
buf := []byte{}
|
||||
buf = AppendUint16(buf, 1)
|
||||
if !reflect.DeepEqual(buf, []byte{0, 1}) {
|
||||
t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint16BufWithCapacityDoesNotAllocate(t *testing.T) {
|
||||
buf := make([]byte, 0, 4)
|
||||
AppendUint16(buf, 1)
|
||||
buf = buf[0:2]
|
||||
if !reflect.DeepEqual(buf, []byte{0, 1}) {
|
||||
t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint32NilBuf(t *testing.T) {
|
||||
buf := AppendUint32(nil, 1)
|
||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) {
|
||||
t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint32EmptyBuf(t *testing.T) {
|
||||
buf := []byte{}
|
||||
buf = AppendUint32(buf, 1)
|
||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) {
|
||||
t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint32BufWithCapacityDoesNotAllocate(t *testing.T) {
|
||||
buf := make([]byte, 0, 4)
|
||||
AppendUint32(buf, 1)
|
||||
buf = buf[0:4]
|
||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) {
|
||||
t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint64NilBuf(t *testing.T) {
|
||||
buf := AppendUint64(nil, 1)
|
||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) {
|
||||
t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint64EmptyBuf(t *testing.T) {
|
||||
buf := []byte{}
|
||||
buf = AppendUint64(buf, 1)
|
||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) {
|
||||
t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUint64BufWithCapacityDoesNotAllocate(t *testing.T) {
|
||||
buf := make([]byte, 0, 8)
|
||||
AppendUint64(buf, 1)
|
||||
buf = buf[0:8]
|
||||
if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) {
|
||||
t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1})
|
||||
}
|
||||
}
|
||||
+2
-2
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/jackc/pgproto3"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
@@ -43,7 +43,7 @@ func (s *Server) ServeOne() error {
|
||||
|
||||
s.Close()
|
||||
|
||||
backend, err := pgproto3.NewBackend(conn, conn)
|
||||
backend, err := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
|
||||
@@ -1,109 +0,0 @@
|
||||
package pgpassfile
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Entry represents a line in a PG passfile.
|
||||
type Entry struct {
|
||||
Hostname string
|
||||
Port string
|
||||
Database string
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
// Passfile is the in memory data structure representing a PG passfile.
|
||||
type Passfile struct {
|
||||
Entries []*Entry
|
||||
}
|
||||
|
||||
// ReadPassfile reads the file at path and parses it into a Passfile.
|
||||
func ReadPassfile(path string) (*Passfile, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
return ParsePassfile(f)
|
||||
}
|
||||
|
||||
// ParsePassfile reads r and parses it into a Passfile.
|
||||
func ParsePassfile(r io.Reader) (*Passfile, error) {
|
||||
passfile := &Passfile{}
|
||||
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
entry := parseLine(scanner.Text())
|
||||
if entry != nil {
|
||||
passfile.Entries = append(passfile.Entries, entry)
|
||||
}
|
||||
}
|
||||
|
||||
return passfile, scanner.Err()
|
||||
}
|
||||
|
||||
// Match (not colons or escaped colon or escaped backslash)+. Essentially gives a split on unescaped
|
||||
// colon.
|
||||
var colonSplitterRegexp = regexp.MustCompile("(([^:]|(\\:)))+")
|
||||
|
||||
// var colonSplitterRegexp = regexp.MustCompile("((?:[^:]|(?:\\:)|(?:\\\\))+)")
|
||||
|
||||
// parseLine parses a line into an *Entry. It returns nil on comment lines or any other unparsable
|
||||
// line.
|
||||
func parseLine(line string) *Entry {
|
||||
const (
|
||||
tmpBackslash = "\r"
|
||||
tmpColon = "\n"
|
||||
)
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if strings.HasPrefix(line, "#") {
|
||||
return nil
|
||||
}
|
||||
|
||||
line = strings.Replace(line, `\\`, tmpBackslash, -1)
|
||||
line = strings.Replace(line, `\:`, tmpColon, -1)
|
||||
|
||||
parts := strings.Split(line, ":")
|
||||
if len(parts) != 5 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unescape escaped colons and backslashes
|
||||
for i := range parts {
|
||||
parts[i] = strings.Replace(parts[i], tmpBackslash, `\`, -1)
|
||||
parts[i] = strings.Replace(parts[i], tmpColon, `:`, -1)
|
||||
}
|
||||
|
||||
return &Entry{
|
||||
Hostname: parts[0],
|
||||
Port: parts[1],
|
||||
Database: parts[2],
|
||||
Username: parts[3],
|
||||
Password: parts[4],
|
||||
}
|
||||
}
|
||||
|
||||
// FindPassword finds the password for the provided hostname, port, database, and username. For a
|
||||
// Unix domain socket hostname must be set to "localhost". An empty string will be returned if no
|
||||
// match is found.
|
||||
//
|
||||
// See https://www.postgresql.org/docs/current/libpq-pgpass.html for more password file information.
|
||||
func (pf *Passfile) FindPassword(hostname, port, database, username string) (password string) {
|
||||
for _, e := range pf.Entries {
|
||||
if (e.Hostname == "*" || e.Hostname == hostname) &&
|
||||
(e.Port == "*" || e.Port == port) &&
|
||||
(e.Database == "*" || e.Database == database) &&
|
||||
(e.Username == "*" || e.Username == username) {
|
||||
return e.Password
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -1,52 +0,0 @@
|
||||
package pgpassfile
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func unescape(s string) string {
|
||||
s = strings.Replace(s, `\:`, `:`, -1)
|
||||
s = strings.Replace(s, `\\`, `\`, -1)
|
||||
return s
|
||||
}
|
||||
|
||||
var passfile = [][]string{
|
||||
{"test1", "5432", "larrydb", "larry", "whatstheidea"},
|
||||
{"test1", "5432", "moedb", "moe", "imbecile"},
|
||||
{"test1", "5432", "curlydb", "curly", "nyuknyuknyuk"},
|
||||
{"test2", "5432", "*", "shemp", "heymoe"},
|
||||
{"test2", "5432", "*", "*", `test\\ing\:`},
|
||||
{"localhost", "*", "*", "*", "sesam"},
|
||||
{"test3", "*", "", "", "swordfish"}, // user will be filled later
|
||||
}
|
||||
|
||||
func TestParsePassFile(t *testing.T) {
|
||||
buf := bytes.NewBufferString(`# A comment
|
||||
test1:5432:larrydb:larry:whatstheidea
|
||||
test1:5432:moedb:moe:imbecile
|
||||
test1:5432:curlydb:curly:nyuknyuknyuk
|
||||
test2:5432:*:shemp:heymoe
|
||||
test2:5432:*:*:test\\ing\:
|
||||
localhost:*:*:*:sesam
|
||||
`)
|
||||
|
||||
passfile, err := ParsePassfile(buf)
|
||||
require.Nil(t, err)
|
||||
|
||||
assert.Len(t, passfile.Entries, 6)
|
||||
|
||||
assert.Equal(t, "whatstheidea", passfile.FindPassword("test1", "5432", "larrydb", "larry"))
|
||||
assert.Equal(t, "imbecile", passfile.FindPassword("test1", "5432", "moedb", "moe"))
|
||||
assert.Equal(t, `test\ing:`, passfile.FindPassword("test2", "5432", "something", "else"))
|
||||
assert.Equal(t, "sesam", passfile.FindPassword("localhost", "9999", "foo", "bare"))
|
||||
|
||||
assert.Equal(t, "", passfile.FindPassword("wrong", "5432", "larrydb", "larry"))
|
||||
assert.Equal(t, "", passfile.FindPassword("test1", "wrong", "larrydb", "larry"))
|
||||
assert.Equal(t, "", passfile.FindPassword("test1", "5432", "wrong", "larry"))
|
||||
assert.Equal(t, "", passfile.FindPassword("test1", "5432", "larrydb", "wrong"))
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
AuthTypeOk = 0
|
||||
AuthTypeCleartextPassword = 3
|
||||
AuthTypeMD5Password = 5
|
||||
)
|
||||
|
||||
type Authentication struct {
|
||||
Type uint32
|
||||
|
||||
// MD5Password fields
|
||||
Salt [4]byte
|
||||
}
|
||||
|
||||
func (*Authentication) Backend() {}
|
||||
|
||||
func (dst *Authentication) Decode(src []byte) error {
|
||||
*dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])}
|
||||
|
||||
switch dst.Type {
|
||||
case AuthTypeOk:
|
||||
case AuthTypeCleartextPassword:
|
||||
case AuthTypeMD5Password:
|
||||
copy(dst.Salt[:], src[4:8])
|
||||
default:
|
||||
return errors.Errorf("unknown authentication type: %d", dst.Type)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Authentication) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
dst = pgio.AppendUint32(dst, src.Type)
|
||||
|
||||
switch src.Type {
|
||||
case AuthTypeMD5Password:
|
||||
dst = append(dst, src.Salt[:]...)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
@@ -1,113 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"github.com/jackc/pgx/chunkreader"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Backend struct {
|
||||
cr *chunkreader.ChunkReader
|
||||
w io.Writer
|
||||
|
||||
// Frontend message flyweights
|
||||
bind Bind
|
||||
_close Close
|
||||
copyFail CopyFail
|
||||
describe Describe
|
||||
execute Execute
|
||||
flush Flush
|
||||
parse Parse
|
||||
passwordMessage PasswordMessage
|
||||
query Query
|
||||
startupMessage StartupMessage
|
||||
sync Sync
|
||||
terminate Terminate
|
||||
|
||||
bodyLen int
|
||||
msgType byte
|
||||
partialMsg bool
|
||||
}
|
||||
|
||||
func NewBackend(r io.Reader, w io.Writer) (*Backend, error) {
|
||||
cr := chunkreader.NewChunkReader(r)
|
||||
return &Backend{cr: cr, w: w}, nil
|
||||
}
|
||||
|
||||
func (b *Backend) Send(msg BackendMessage) error {
|
||||
_, err := b.w.Write(msg.Encode(nil))
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) {
|
||||
buf, err := b.cr.Next(4)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msgSize := int(binary.BigEndian.Uint32(buf) - 4)
|
||||
|
||||
buf, err = b.cr.Next(msgSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = b.startupMessage.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &b.startupMessage, nil
|
||||
}
|
||||
|
||||
func (b *Backend) Receive() (FrontendMessage, error) {
|
||||
if !b.partialMsg {
|
||||
header, err := b.cr.Next(5)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.msgType = header[0]
|
||||
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
|
||||
b.partialMsg = true
|
||||
}
|
||||
|
||||
var msg FrontendMessage
|
||||
switch b.msgType {
|
||||
case 'B':
|
||||
msg = &b.bind
|
||||
case 'C':
|
||||
msg = &b._close
|
||||
case 'D':
|
||||
msg = &b.describe
|
||||
case 'E':
|
||||
msg = &b.execute
|
||||
case 'f':
|
||||
msg = &b.copyFail
|
||||
case 'H':
|
||||
msg = &b.flush
|
||||
case 'P':
|
||||
msg = &b.parse
|
||||
case 'p':
|
||||
msg = &b.passwordMessage
|
||||
case 'Q':
|
||||
msg = &b.query
|
||||
case 'S':
|
||||
msg = &b.sync
|
||||
case 'X':
|
||||
msg = &b.terminate
|
||||
default:
|
||||
return nil, errors.Errorf("unknown message type: %c", b.msgType)
|
||||
}
|
||||
|
||||
msgBody, err := b.cr.Next(b.bodyLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.partialMsg = false
|
||||
|
||||
err = msg.Decode(msgBody)
|
||||
return msg, err
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type BackendKeyData struct {
|
||||
ProcessID uint32
|
||||
SecretKey uint32
|
||||
}
|
||||
|
||||
func (*BackendKeyData) Backend() {}
|
||||
|
||||
func (dst *BackendKeyData) Decode(src []byte) error {
|
||||
if len(src) != 8 {
|
||||
return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)}
|
||||
}
|
||||
|
||||
dst.ProcessID = binary.BigEndian.Uint32(src[:4])
|
||||
dst.SecretKey = binary.BigEndian.Uint32(src[4:])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *BackendKeyData) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'K')
|
||||
dst = pgio.AppendUint32(dst, 12)
|
||||
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *BackendKeyData) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ProcessID uint32
|
||||
SecretKey uint32
|
||||
}{
|
||||
Type: "BackendKeyData",
|
||||
ProcessID: src.ProcessID,
|
||||
SecretKey: src.SecretKey,
|
||||
})
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package pgproto3_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
)
|
||||
|
||||
func TestBackendReceiveInterrupted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := &interruptReader{}
|
||||
server.push([]byte{'Q', 0, 0, 0, 6})
|
||||
|
||||
backend, err := pgproto3.NewBackend(server, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
msg, err := backend.Receive()
|
||||
if err == nil {
|
||||
t.Fatal("expected err")
|
||||
}
|
||||
if msg != nil {
|
||||
t.Fatalf("did not expect msg, but %v", msg)
|
||||
}
|
||||
|
||||
server.push([]byte{'I', 0})
|
||||
|
||||
msg, err = backend.Receive()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msg, ok := msg.(*pgproto3.Query); !ok || msg.String != "I" {
|
||||
t.Fatalf("unexpected msg: %v", msg)
|
||||
}
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
type BigEndianBuf [8]byte
|
||||
|
||||
func (b BigEndianBuf) Int16(n int16) []byte {
|
||||
buf := b[0:2]
|
||||
binary.BigEndian.PutUint16(buf, uint16(n))
|
||||
return buf
|
||||
}
|
||||
|
||||
func (b BigEndianBuf) Uint16(n uint16) []byte {
|
||||
buf := b[0:2]
|
||||
binary.BigEndian.PutUint16(buf, n)
|
||||
return buf
|
||||
}
|
||||
|
||||
func (b BigEndianBuf) Int32(n int32) []byte {
|
||||
buf := b[0:4]
|
||||
binary.BigEndian.PutUint32(buf, uint32(n))
|
||||
return buf
|
||||
}
|
||||
|
||||
func (b BigEndianBuf) Uint32(n uint32) []byte {
|
||||
buf := b[0:4]
|
||||
binary.BigEndian.PutUint32(buf, n)
|
||||
return buf
|
||||
}
|
||||
|
||||
func (b BigEndianBuf) Int64(n int64) []byte {
|
||||
buf := b[0:8]
|
||||
binary.BigEndian.PutUint64(buf, uint64(n))
|
||||
return buf
|
||||
}
|
||||
@@ -1,171 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type Bind struct {
|
||||
DestinationPortal string
|
||||
PreparedStatement string
|
||||
ParameterFormatCodes []int16
|
||||
Parameters [][]byte
|
||||
ResultFormatCodes []int16
|
||||
}
|
||||
|
||||
func (*Bind) Frontend() {}
|
||||
|
||||
func (dst *Bind) Decode(src []byte) error {
|
||||
*dst = Bind{}
|
||||
|
||||
idx := bytes.IndexByte(src, 0)
|
||||
if idx < 0 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
dst.DestinationPortal = string(src[:idx])
|
||||
rp := idx + 1
|
||||
|
||||
idx = bytes.IndexByte(src[rp:], 0)
|
||||
if idx < 0 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
dst.PreparedStatement = string(src[rp : rp+idx])
|
||||
rp += idx + 1
|
||||
|
||||
if len(src[rp:]) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
|
||||
if parameterFormatCodeCount > 0 {
|
||||
dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount)
|
||||
|
||||
if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
for i := 0; i < parameterFormatCodeCount; i++ {
|
||||
dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
}
|
||||
}
|
||||
|
||||
if len(src[rp:]) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
parameterCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
|
||||
if parameterCount > 0 {
|
||||
dst.Parameters = make([][]byte, parameterCount)
|
||||
|
||||
for i := 0; i < parameterCount; i++ {
|
||||
if len(src[rp:]) < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
|
||||
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||
rp += 4
|
||||
|
||||
// null
|
||||
if msgSize == -1 {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(src[rp:]) < msgSize {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
|
||||
dst.Parameters[i] = src[rp : rp+msgSize]
|
||||
rp += msgSize
|
||||
}
|
||||
}
|
||||
|
||||
if len(src[rp:]) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
|
||||
dst.ResultFormatCodes = make([]int16, resultFormatCodeCount)
|
||||
if len(src[rp:]) < len(dst.ResultFormatCodes)*2 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
for i := 0; i < resultFormatCodeCount; i++ {
|
||||
dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Bind) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'B')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.DestinationPortal...)
|
||||
dst = append(dst, 0)
|
||||
dst = append(dst, src.PreparedStatement...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
|
||||
for _, fc := range src.ParameterFormatCodes {
|
||||
dst = pgio.AppendInt16(dst, fc)
|
||||
}
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
|
||||
for _, p := range src.Parameters {
|
||||
if p == nil {
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
continue
|
||||
}
|
||||
|
||||
dst = pgio.AppendInt32(dst, int32(len(p)))
|
||||
dst = append(dst, p...)
|
||||
}
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
|
||||
for _, fc := range src.ResultFormatCodes {
|
||||
dst = pgio.AppendInt16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *Bind) MarshalJSON() ([]byte, error) {
|
||||
formattedParameters := make([]map[string]string, len(src.Parameters))
|
||||
for i, p := range src.Parameters {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if src.ParameterFormatCodes[i] == 0 {
|
||||
formattedParameters[i] = map[string]string{"text": string(p)}
|
||||
} else {
|
||||
formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)}
|
||||
}
|
||||
}
|
||||
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
DestinationPortal string
|
||||
PreparedStatement string
|
||||
ParameterFormatCodes []int16
|
||||
Parameters []map[string]string
|
||||
ResultFormatCodes []int16
|
||||
}{
|
||||
Type: "Bind",
|
||||
DestinationPortal: src.DestinationPortal,
|
||||
PreparedStatement: src.PreparedStatement,
|
||||
ParameterFormatCodes: src.ParameterFormatCodes,
|
||||
Parameters: formattedParameters,
|
||||
ResultFormatCodes: src.ResultFormatCodes,
|
||||
})
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type BindComplete struct{}
|
||||
|
||||
func (*BindComplete) Backend() {}
|
||||
|
||||
func (dst *BindComplete) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *BindComplete) Encode(dst []byte) []byte {
|
||||
return append(dst, '2', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
func (src *BindComplete) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "BindComplete",
|
||||
})
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type Close struct {
|
||||
ObjectType byte // 'S' = prepared statement, 'P' = portal
|
||||
Name string
|
||||
}
|
||||
|
||||
func (*Close) Frontend() {}
|
||||
|
||||
func (dst *Close) Decode(src []byte) error {
|
||||
if len(src) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "Close"}
|
||||
}
|
||||
|
||||
dst.ObjectType = src[0]
|
||||
rp := 1
|
||||
|
||||
idx := bytes.IndexByte(src[rp:], 0)
|
||||
if idx != len(src[rp:])-1 {
|
||||
return &invalidMessageFormatErr{messageType: "Close"}
|
||||
}
|
||||
|
||||
dst.Name = string(src[rp : len(src)-1])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Close) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'C')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.ObjectType)
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *Close) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ObjectType string
|
||||
Name string
|
||||
}{
|
||||
Type: "Close",
|
||||
ObjectType: string(src.ObjectType),
|
||||
Name: src.Name,
|
||||
})
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type CloseComplete struct{}
|
||||
|
||||
func (*CloseComplete) Backend() {}
|
||||
|
||||
func (dst *CloseComplete) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *CloseComplete) Encode(dst []byte) []byte {
|
||||
return append(dst, '3', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
func (src *CloseComplete) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "CloseComplete",
|
||||
})
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type CommandComplete struct {
|
||||
CommandTag string
|
||||
}
|
||||
|
||||
func (*CommandComplete) Backend() {}
|
||||
|
||||
func (dst *CommandComplete) Decode(src []byte) error {
|
||||
idx := bytes.IndexByte(src, 0)
|
||||
if idx != len(src)-1 {
|
||||
return &invalidMessageFormatErr{messageType: "CommandComplete"}
|
||||
}
|
||||
|
||||
dst.CommandTag = string(src[:idx])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *CommandComplete) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'C')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.CommandTag...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *CommandComplete) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
CommandTag string
|
||||
}{
|
||||
Type: "CommandComplete",
|
||||
CommandTag: src.CommandTag,
|
||||
})
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type CopyBothResponse struct {
|
||||
OverallFormat byte
|
||||
ColumnFormatCodes []uint16
|
||||
}
|
||||
|
||||
func (*CopyBothResponse) Backend() {}
|
||||
|
||||
func (dst *CopyBothResponse) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 3 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
|
||||
}
|
||||
|
||||
overallFormat := buf.Next(1)[0]
|
||||
|
||||
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
if buf.Len() != columnCount*2 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
|
||||
}
|
||||
|
||||
columnFormatCodes := make([]uint16, columnCount)
|
||||
for i := 0; i < columnCount; i++ {
|
||||
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||
}
|
||||
|
||||
*dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *CopyBothResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'W')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||
for _, fc := range src.ColumnFormatCodes {
|
||||
dst = pgio.AppendUint16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *CopyBothResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ColumnFormatCodes []uint16
|
||||
}{
|
||||
Type: "CopyBothResponse",
|
||||
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||
})
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type CopyData struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (*CopyData) Backend() {}
|
||||
func (*CopyData) Frontend() {}
|
||||
|
||||
func (dst *CopyData) Decode(src []byte) error {
|
||||
dst.Data = src
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *CopyData) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'd')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
|
||||
dst = append(dst, src.Data...)
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *CopyData) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Data string
|
||||
}{
|
||||
Type: "CopyData",
|
||||
Data: hex.EncodeToString(src.Data),
|
||||
})
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type CopyDone struct {
|
||||
}
|
||||
|
||||
func (*CopyDone) Backend() {}
|
||||
|
||||
func (dst *CopyDone) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "CopyDone", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *CopyDone) Encode(dst []byte) []byte {
|
||||
return append(dst, 'c', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
func (src *CopyDone) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "CopyDone",
|
||||
})
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type CopyFail struct {
|
||||
Error string
|
||||
}
|
||||
|
||||
func (*CopyFail) Frontend() {}
|
||||
func (*CopyFail) Backend() {}
|
||||
|
||||
func (dst *CopyFail) Decode(src []byte) error {
|
||||
idx := bytes.IndexByte(src, 0)
|
||||
if idx != len(src)-1 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyFail"}
|
||||
}
|
||||
|
||||
dst.Error = string(src[:idx])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *CopyFail) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'C')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.Error...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *CopyFail) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Error string
|
||||
}{
|
||||
Type: "CopyFail",
|
||||
Error: src.Error,
|
||||
})
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type CopyInResponse struct {
|
||||
OverallFormat byte
|
||||
ColumnFormatCodes []uint16
|
||||
}
|
||||
|
||||
func (*CopyInResponse) Backend() {}
|
||||
|
||||
func (dst *CopyInResponse) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 3 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
|
||||
}
|
||||
|
||||
overallFormat := buf.Next(1)[0]
|
||||
|
||||
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
if buf.Len() != columnCount*2 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
|
||||
}
|
||||
|
||||
columnFormatCodes := make([]uint16, columnCount)
|
||||
for i := 0; i < columnCount; i++ {
|
||||
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||
}
|
||||
|
||||
*dst = CopyInResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *CopyInResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'G')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||
for _, fc := range src.ColumnFormatCodes {
|
||||
dst = pgio.AppendUint16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *CopyInResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ColumnFormatCodes []uint16
|
||||
}{
|
||||
Type: "CopyInResponse",
|
||||
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||
})
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type CopyOutResponse struct {
|
||||
OverallFormat byte
|
||||
ColumnFormatCodes []uint16
|
||||
}
|
||||
|
||||
func (*CopyOutResponse) Backend() {}
|
||||
|
||||
func (dst *CopyOutResponse) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 3 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
|
||||
}
|
||||
|
||||
overallFormat := buf.Next(1)[0]
|
||||
|
||||
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
if buf.Len() != columnCount*2 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
|
||||
}
|
||||
|
||||
columnFormatCodes := make([]uint16, columnCount)
|
||||
for i := 0; i < columnCount; i++ {
|
||||
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||
}
|
||||
|
||||
*dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *CopyOutResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'H')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||
for _, fc := range src.ColumnFormatCodes {
|
||||
dst = pgio.AppendUint16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *CopyOutResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ColumnFormatCodes []uint16
|
||||
}{
|
||||
Type: "CopyOutResponse",
|
||||
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||
})
|
||||
}
|
||||
@@ -1,112 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type DataRow struct {
|
||||
Values [][]byte
|
||||
}
|
||||
|
||||
func (*DataRow) Backend() {}
|
||||
|
||||
func (dst *DataRow) Decode(src []byte) error {
|
||||
if len(src) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||
}
|
||||
rp := 0
|
||||
fieldCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
|
||||
// If the capacity of the values slice is too small OR substantially too
|
||||
// large reallocate. This is too avoid one row with many columns from
|
||||
// permanently allocating memory.
|
||||
if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 {
|
||||
newCap := 32
|
||||
if newCap < fieldCount {
|
||||
newCap = fieldCount
|
||||
}
|
||||
dst.Values = make([][]byte, fieldCount, newCap)
|
||||
} else {
|
||||
dst.Values = dst.Values[:fieldCount]
|
||||
}
|
||||
|
||||
for i := 0; i < fieldCount; i++ {
|
||||
if len(src[rp:]) < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||
}
|
||||
|
||||
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||
rp += 4
|
||||
|
||||
// null
|
||||
if msgSize == -1 {
|
||||
dst.Values[i] = nil
|
||||
} else {
|
||||
if len(src[rp:]) < msgSize {
|
||||
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||
}
|
||||
|
||||
dst.Values[i] = src[rp : rp+msgSize]
|
||||
rp += msgSize
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *DataRow) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'D')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
|
||||
for _, v := range src.Values {
|
||||
if v == nil {
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
continue
|
||||
}
|
||||
|
||||
dst = pgio.AppendInt32(dst, int32(len(v)))
|
||||
dst = append(dst, v...)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *DataRow) MarshalJSON() ([]byte, error) {
|
||||
formattedValues := make([]map[string]string, len(src.Values))
|
||||
for i, v := range src.Values {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var hasNonPrintable bool
|
||||
for _, b := range v {
|
||||
if b < 32 {
|
||||
hasNonPrintable = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasNonPrintable {
|
||||
formattedValues[i] = map[string]string{"binary": hex.EncodeToString(v)}
|
||||
} else {
|
||||
formattedValues[i] = map[string]string{"text": string(v)}
|
||||
}
|
||||
}
|
||||
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Values []map[string]string
|
||||
}{
|
||||
Type: "DataRow",
|
||||
Values: formattedValues,
|
||||
})
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type Describe struct {
|
||||
ObjectType byte // 'S' = prepared statement, 'P' = portal
|
||||
Name string
|
||||
}
|
||||
|
||||
func (*Describe) Frontend() {}
|
||||
|
||||
func (dst *Describe) Decode(src []byte) error {
|
||||
if len(src) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "Describe"}
|
||||
}
|
||||
|
||||
dst.ObjectType = src[0]
|
||||
rp := 1
|
||||
|
||||
idx := bytes.IndexByte(src[rp:], 0)
|
||||
if idx != len(src[rp:])-1 {
|
||||
return &invalidMessageFormatErr{messageType: "Describe"}
|
||||
}
|
||||
|
||||
dst.Name = string(src[rp : len(src)-1])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Describe) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'D')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.ObjectType)
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *Describe) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ObjectType string
|
||||
Name string
|
||||
}{
|
||||
Type: "Describe",
|
||||
ObjectType: string(src.ObjectType),
|
||||
Name: src.Name,
|
||||
})
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type EmptyQueryResponse struct{}
|
||||
|
||||
func (*EmptyQueryResponse) Backend() {}
|
||||
|
||||
func (dst *EmptyQueryResponse) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *EmptyQueryResponse) Encode(dst []byte) []byte {
|
||||
return append(dst, 'I', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
func (src *EmptyQueryResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "EmptyQueryResponse",
|
||||
})
|
||||
}
|
||||
@@ -1,214 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type ErrorResponse struct {
|
||||
Severity string
|
||||
Code string
|
||||
Message string
|
||||
Detail string
|
||||
Hint string
|
||||
Position int32
|
||||
InternalPosition int32
|
||||
InternalQuery string
|
||||
Where string
|
||||
SchemaName string
|
||||
TableName string
|
||||
ColumnName string
|
||||
DataTypeName string
|
||||
ConstraintName string
|
||||
File string
|
||||
Line int32
|
||||
Routine string
|
||||
|
||||
UnknownFields map[byte]string
|
||||
}
|
||||
|
||||
func (*ErrorResponse) Backend() {}
|
||||
|
||||
func (dst *ErrorResponse) Decode(src []byte) error {
|
||||
*dst = ErrorResponse{}
|
||||
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
for {
|
||||
k, err := buf.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if k == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
vb, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v := string(vb[:len(vb)-1])
|
||||
|
||||
switch k {
|
||||
case 'S':
|
||||
dst.Severity = v
|
||||
case 'C':
|
||||
dst.Code = v
|
||||
case 'M':
|
||||
dst.Message = v
|
||||
case 'D':
|
||||
dst.Detail = v
|
||||
case 'H':
|
||||
dst.Hint = v
|
||||
case 'P':
|
||||
s := v
|
||||
n, _ := strconv.ParseInt(s, 10, 32)
|
||||
dst.Position = int32(n)
|
||||
case 'p':
|
||||
s := v
|
||||
n, _ := strconv.ParseInt(s, 10, 32)
|
||||
dst.InternalPosition = int32(n)
|
||||
case 'q':
|
||||
dst.InternalQuery = v
|
||||
case 'W':
|
||||
dst.Where = v
|
||||
case 's':
|
||||
dst.SchemaName = v
|
||||
case 't':
|
||||
dst.TableName = v
|
||||
case 'c':
|
||||
dst.ColumnName = v
|
||||
case 'd':
|
||||
dst.DataTypeName = v
|
||||
case 'n':
|
||||
dst.ConstraintName = v
|
||||
case 'F':
|
||||
dst.File = v
|
||||
case 'L':
|
||||
s := v
|
||||
n, _ := strconv.ParseInt(s, 10, 32)
|
||||
dst.Line = int32(n)
|
||||
case 'R':
|
||||
dst.Routine = v
|
||||
|
||||
default:
|
||||
if dst.UnknownFields == nil {
|
||||
dst.UnknownFields = make(map[byte]string)
|
||||
}
|
||||
dst.UnknownFields[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *ErrorResponse) Encode(dst []byte) []byte {
|
||||
return append(dst, src.marshalBinary('E')...)
|
||||
}
|
||||
|
||||
func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
|
||||
var bigEndian BigEndianBuf
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
buf.WriteByte(typeByte)
|
||||
buf.Write(bigEndian.Uint32(0))
|
||||
|
||||
if src.Severity != "" {
|
||||
buf.WriteByte('S')
|
||||
buf.WriteString(src.Severity)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Code != "" {
|
||||
buf.WriteByte('C')
|
||||
buf.WriteString(src.Code)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Message != "" {
|
||||
buf.WriteByte('M')
|
||||
buf.WriteString(src.Message)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Detail != "" {
|
||||
buf.WriteByte('D')
|
||||
buf.WriteString(src.Detail)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Hint != "" {
|
||||
buf.WriteByte('H')
|
||||
buf.WriteString(src.Hint)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Position != 0 {
|
||||
buf.WriteByte('P')
|
||||
buf.WriteString(strconv.Itoa(int(src.Position)))
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.InternalPosition != 0 {
|
||||
buf.WriteByte('p')
|
||||
buf.WriteString(strconv.Itoa(int(src.InternalPosition)))
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.InternalQuery != "" {
|
||||
buf.WriteByte('q')
|
||||
buf.WriteString(src.InternalQuery)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Where != "" {
|
||||
buf.WriteByte('W')
|
||||
buf.WriteString(src.Where)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.SchemaName != "" {
|
||||
buf.WriteByte('s')
|
||||
buf.WriteString(src.SchemaName)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.TableName != "" {
|
||||
buf.WriteByte('t')
|
||||
buf.WriteString(src.TableName)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.ColumnName != "" {
|
||||
buf.WriteByte('c')
|
||||
buf.WriteString(src.ColumnName)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.DataTypeName != "" {
|
||||
buf.WriteByte('d')
|
||||
buf.WriteString(src.DataTypeName)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.ConstraintName != "" {
|
||||
buf.WriteByte('n')
|
||||
buf.WriteString(src.ConstraintName)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.File != "" {
|
||||
buf.WriteByte('F')
|
||||
buf.WriteString(src.File)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Line != 0 {
|
||||
buf.WriteByte('L')
|
||||
buf.WriteString(strconv.Itoa(int(src.Line)))
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Routine != "" {
|
||||
buf.WriteByte('R')
|
||||
buf.WriteString(src.Routine)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
|
||||
for k, v := range src.UnknownFields {
|
||||
buf.WriteByte(k)
|
||||
buf.WriteByte(0)
|
||||
buf.WriteString(v)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
buf.WriteByte(0)
|
||||
|
||||
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type Execute struct {
|
||||
Portal string
|
||||
MaxRows uint32
|
||||
}
|
||||
|
||||
func (*Execute) Frontend() {}
|
||||
|
||||
func (dst *Execute) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
b, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst.Portal = string(b[:len(b)-1])
|
||||
|
||||
if buf.Len() < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "Execute"}
|
||||
}
|
||||
dst.MaxRows = binary.BigEndian.Uint32(buf.Next(4))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Execute) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'E')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.Portal...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
dst = pgio.AppendUint32(dst, src.MaxRows)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *Execute) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Portal string
|
||||
MaxRows uint32
|
||||
}{
|
||||
Type: "Execute",
|
||||
Portal: src.Portal,
|
||||
MaxRows: src.MaxRows,
|
||||
})
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type Flush struct{}
|
||||
|
||||
func (*Flush) Frontend() {}
|
||||
|
||||
func (dst *Flush) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "Flush", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Flush) Encode(dst []byte) []byte {
|
||||
return append(dst, 'H', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
func (src *Flush) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "Flush",
|
||||
})
|
||||
}
|
||||
@@ -1,128 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"github.com/jackc/pgx/chunkreader"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Frontend struct {
|
||||
cr *chunkreader.ChunkReader
|
||||
w io.Writer
|
||||
|
||||
// Backend message flyweights
|
||||
authentication Authentication
|
||||
backendKeyData BackendKeyData
|
||||
bindComplete BindComplete
|
||||
closeComplete CloseComplete
|
||||
commandComplete CommandComplete
|
||||
copyBothResponse CopyBothResponse
|
||||
copyData CopyData
|
||||
copyInResponse CopyInResponse
|
||||
copyOutResponse CopyOutResponse
|
||||
copyDone CopyDone
|
||||
copyFail CopyFail
|
||||
dataRow DataRow
|
||||
emptyQueryResponse EmptyQueryResponse
|
||||
errorResponse ErrorResponse
|
||||
functionCallResponse FunctionCallResponse
|
||||
noData NoData
|
||||
noticeResponse NoticeResponse
|
||||
notificationResponse NotificationResponse
|
||||
parameterDescription ParameterDescription
|
||||
parameterStatus ParameterStatus
|
||||
parseComplete ParseComplete
|
||||
readyForQuery ReadyForQuery
|
||||
rowDescription RowDescription
|
||||
|
||||
bodyLen int
|
||||
msgType byte
|
||||
partialMsg bool
|
||||
}
|
||||
|
||||
func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) {
|
||||
cr := chunkreader.NewChunkReader(r)
|
||||
return &Frontend{cr: cr, w: w}, nil
|
||||
}
|
||||
|
||||
func (b *Frontend) Send(msg FrontendMessage) error {
|
||||
_, err := b.w.Write(msg.Encode(nil))
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *Frontend) Receive() (BackendMessage, error) {
|
||||
if !b.partialMsg {
|
||||
header, err := b.cr.Next(5)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.msgType = header[0]
|
||||
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
|
||||
b.partialMsg = true
|
||||
}
|
||||
|
||||
var msg BackendMessage
|
||||
switch b.msgType {
|
||||
case '1':
|
||||
msg = &b.parseComplete
|
||||
case '2':
|
||||
msg = &b.bindComplete
|
||||
case '3':
|
||||
msg = &b.closeComplete
|
||||
case 'A':
|
||||
msg = &b.notificationResponse
|
||||
case 'c':
|
||||
msg = &b.copyDone
|
||||
case 'C':
|
||||
msg = &b.commandComplete
|
||||
case 'd':
|
||||
msg = &b.copyData
|
||||
case 'D':
|
||||
msg = &b.dataRow
|
||||
case 'E':
|
||||
msg = &b.errorResponse
|
||||
case 'f':
|
||||
msg = &b.copyFail
|
||||
case 'G':
|
||||
msg = &b.copyInResponse
|
||||
case 'H':
|
||||
msg = &b.copyOutResponse
|
||||
case 'I':
|
||||
msg = &b.emptyQueryResponse
|
||||
case 'K':
|
||||
msg = &b.backendKeyData
|
||||
case 'n':
|
||||
msg = &b.noData
|
||||
case 'N':
|
||||
msg = &b.noticeResponse
|
||||
case 'R':
|
||||
msg = &b.authentication
|
||||
case 'S':
|
||||
msg = &b.parameterStatus
|
||||
case 't':
|
||||
msg = &b.parameterDescription
|
||||
case 'T':
|
||||
msg = &b.rowDescription
|
||||
case 'V':
|
||||
msg = &b.functionCallResponse
|
||||
case 'W':
|
||||
msg = &b.copyBothResponse
|
||||
case 'Z':
|
||||
msg = &b.readyForQuery
|
||||
default:
|
||||
return nil, errors.Errorf("unknown message type: %c", b.msgType)
|
||||
}
|
||||
|
||||
msgBody, err := b.cr.Next(b.bodyLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.partialMsg = false
|
||||
|
||||
err = msg.Decode(msgBody)
|
||||
return msg, err
|
||||
}
|
||||
@@ -1,62 +0,0 @@
|
||||
package pgproto3_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
)
|
||||
|
||||
type interruptReader struct {
|
||||
chunks [][]byte
|
||||
}
|
||||
|
||||
func (ir *interruptReader) Read(p []byte) (n int, err error) {
|
||||
if len(ir.chunks) == 0 {
|
||||
return 0, errors.New("no data")
|
||||
}
|
||||
|
||||
n = copy(p, ir.chunks[0])
|
||||
if n != len(ir.chunks[0]) {
|
||||
panic("this test reader doesn't support partial reads of chunks")
|
||||
}
|
||||
|
||||
ir.chunks = ir.chunks[1:]
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (ir *interruptReader) push(p []byte) {
|
||||
ir.chunks = append(ir.chunks, p)
|
||||
}
|
||||
|
||||
func TestFrontendReceiveInterrupted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := &interruptReader{}
|
||||
server.push([]byte{'Z', 0, 0, 0, 5})
|
||||
|
||||
frontend, err := pgproto3.NewFrontend(server, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
msg, err := frontend.Receive()
|
||||
if err == nil {
|
||||
t.Fatal("expected err")
|
||||
}
|
||||
if msg != nil {
|
||||
t.Fatalf("did not expect msg, but %v", msg)
|
||||
}
|
||||
|
||||
server.push([]byte{'I'})
|
||||
|
||||
msg, err = frontend.Receive()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msg, ok := msg.(*pgproto3.ReadyForQuery); !ok || msg.TxStatus != 'I' {
|
||||
t.Fatalf("unexpected msg: %v", msg)
|
||||
}
|
||||
}
|
||||
@@ -1,78 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type FunctionCallResponse struct {
|
||||
Result []byte
|
||||
}
|
||||
|
||||
func (*FunctionCallResponse) Backend() {}
|
||||
|
||||
func (dst *FunctionCallResponse) Decode(src []byte) error {
|
||||
if len(src) < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
|
||||
}
|
||||
rp := 0
|
||||
resultSize := int(binary.BigEndian.Uint32(src[rp:]))
|
||||
rp += 4
|
||||
|
||||
if resultSize == -1 {
|
||||
dst.Result = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(src[rp:]) != resultSize {
|
||||
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
|
||||
}
|
||||
|
||||
dst.Result = src[rp:]
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *FunctionCallResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'V')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
if src.Result == nil {
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
} else {
|
||||
dst = pgio.AppendInt32(dst, int32(len(src.Result)))
|
||||
dst = append(dst, src.Result...)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *FunctionCallResponse) MarshalJSON() ([]byte, error) {
|
||||
var formattedValue map[string]string
|
||||
var hasNonPrintable bool
|
||||
for _, b := range src.Result {
|
||||
if b < 32 {
|
||||
hasNonPrintable = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasNonPrintable {
|
||||
formattedValue = map[string]string{"binary": hex.EncodeToString(src.Result)}
|
||||
} else {
|
||||
formattedValue = map[string]string{"text": string(src.Result)}
|
||||
}
|
||||
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Result map[string]string
|
||||
}{
|
||||
Type: "FunctionCallResponse",
|
||||
Result: formattedValue,
|
||||
})
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type NoData struct{}
|
||||
|
||||
func (*NoData) Backend() {}
|
||||
|
||||
func (dst *NoData) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *NoData) Encode(dst []byte) []byte {
|
||||
return append(dst, 'n', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
func (src *NoData) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "NoData",
|
||||
})
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
type NoticeResponse ErrorResponse
|
||||
|
||||
func (*NoticeResponse) Backend() {}
|
||||
|
||||
func (dst *NoticeResponse) Decode(src []byte) error {
|
||||
return (*ErrorResponse)(dst).Decode(src)
|
||||
}
|
||||
|
||||
func (src *NoticeResponse) Encode(dst []byte) []byte {
|
||||
return append(dst, (*ErrorResponse)(src).marshalBinary('N')...)
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type NotificationResponse struct {
|
||||
PID uint32
|
||||
Channel string
|
||||
Payload string
|
||||
}
|
||||
|
||||
func (*NotificationResponse) Backend() {}
|
||||
|
||||
func (dst *NotificationResponse) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
pid := binary.BigEndian.Uint32(buf.Next(4))
|
||||
|
||||
b, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
channel := string(b[:len(b)-1])
|
||||
|
||||
b, err = buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload := string(b[:len(b)-1])
|
||||
|
||||
*dst = NotificationResponse{PID: pid, Channel: channel, Payload: payload}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *NotificationResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'A')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.Channel...)
|
||||
dst = append(dst, 0)
|
||||
dst = append(dst, src.Payload...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *NotificationResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
PID uint32
|
||||
Channel string
|
||||
Payload string
|
||||
}{
|
||||
Type: "NotificationResponse",
|
||||
PID: src.PID,
|
||||
Channel: src.Channel,
|
||||
Payload: src.Payload,
|
||||
})
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type ParameterDescription struct {
|
||||
ParameterOIDs []uint32
|
||||
}
|
||||
|
||||
func (*ParameterDescription) Backend() {}
|
||||
|
||||
func (dst *ParameterDescription) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "ParameterDescription"}
|
||||
}
|
||||
|
||||
// Reported parameter count will be incorrect when number of args is greater than uint16
|
||||
buf.Next(2)
|
||||
// Instead infer parameter count by remaining size of message
|
||||
parameterCount := buf.Len() / 4
|
||||
|
||||
*dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)}
|
||||
|
||||
for i := 0; i < parameterCount; i++ {
|
||||
dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *ParameterDescription) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 't')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
|
||||
for _, oid := range src.ParameterOIDs {
|
||||
dst = pgio.AppendUint32(dst, oid)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *ParameterDescription) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ParameterOIDs []uint32
|
||||
}{
|
||||
Type: "ParameterDescription",
|
||||
ParameterOIDs: src.ParameterOIDs,
|
||||
})
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type ParameterStatus struct {
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
|
||||
func (*ParameterStatus) Backend() {}
|
||||
|
||||
func (dst *ParameterStatus) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
b, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
name := string(b[:len(b)-1])
|
||||
|
||||
b, err = buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
value := string(b[:len(b)-1])
|
||||
|
||||
*dst = ParameterStatus{Name: name, Value: value}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *ParameterStatus) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'S')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
dst = append(dst, src.Value...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (ps *ParameterStatus) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Name string
|
||||
Value string
|
||||
}{
|
||||
Type: "ParameterStatus",
|
||||
Name: ps.Name,
|
||||
Value: ps.Value,
|
||||
})
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type Parse struct {
|
||||
Name string
|
||||
Query string
|
||||
ParameterOIDs []uint32
|
||||
}
|
||||
|
||||
func (*Parse) Frontend() {}
|
||||
|
||||
func (dst *Parse) Decode(src []byte) error {
|
||||
*dst = Parse{}
|
||||
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
b, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst.Name = string(b[:len(b)-1])
|
||||
|
||||
b, err = buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst.Query = string(b[:len(b)-1])
|
||||
|
||||
if buf.Len() < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "Parse"}
|
||||
}
|
||||
parameterOIDCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
|
||||
for i := 0; i < parameterOIDCount; i++ {
|
||||
if buf.Len() < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "Parse"}
|
||||
}
|
||||
dst.ParameterOIDs = append(dst.ParameterOIDs, binary.BigEndian.Uint32(buf.Next(4)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Parse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'P')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
dst = append(dst, src.Query...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
|
||||
for _, oid := range src.ParameterOIDs {
|
||||
dst = pgio.AppendUint32(dst, oid)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *Parse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Name string
|
||||
Query string
|
||||
ParameterOIDs []uint32
|
||||
}{
|
||||
Type: "Parse",
|
||||
Name: src.Name,
|
||||
Query: src.Query,
|
||||
ParameterOIDs: src.ParameterOIDs,
|
||||
})
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type ParseComplete struct{}
|
||||
|
||||
func (*ParseComplete) Backend() {}
|
||||
|
||||
func (dst *ParseComplete) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *ParseComplete) Encode(dst []byte) []byte {
|
||||
return append(dst, '1', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
func (src *ParseComplete) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "ParseComplete",
|
||||
})
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type PasswordMessage struct {
|
||||
Password string
|
||||
}
|
||||
|
||||
func (*PasswordMessage) Frontend() {}
|
||||
|
||||
func (dst *PasswordMessage) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
b, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst.Password = string(b[:len(b)-1])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *PasswordMessage) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'p')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1))
|
||||
|
||||
dst = append(dst, src.Password...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *PasswordMessage) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Password string
|
||||
}{
|
||||
Type: "PasswordMessage",
|
||||
Password: src.Password,
|
||||
})
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Message is the interface implemented by an object that can decode and encode
|
||||
// a particular PostgreSQL message.
|
||||
type Message interface {
|
||||
// Decode is allowed and expected to retain a reference to data after
|
||||
// returning (unlike encoding.BinaryUnmarshaler).
|
||||
Decode(data []byte) error
|
||||
|
||||
// Encode appends itself to dst and returns the new buffer.
|
||||
Encode(dst []byte) []byte
|
||||
}
|
||||
|
||||
type FrontendMessage interface {
|
||||
Message
|
||||
Frontend() // no-op method to distinguish frontend from backend methods
|
||||
}
|
||||
|
||||
type BackendMessage interface {
|
||||
Message
|
||||
Backend() // no-op method to distinguish frontend from backend methods
|
||||
}
|
||||
|
||||
type invalidMessageLenErr struct {
|
||||
messageType string
|
||||
expectedLen int
|
||||
actualLen int
|
||||
}
|
||||
|
||||
func (e *invalidMessageLenErr) Error() string {
|
||||
return fmt.Sprintf("%s body must have length of %d, but it is %d", e.messageType, e.expectedLen, e.actualLen)
|
||||
}
|
||||
|
||||
type invalidMessageFormatErr struct {
|
||||
messageType string
|
||||
}
|
||||
|
||||
func (e *invalidMessageFormatErr) Error() string {
|
||||
return fmt.Sprintf("%s body is invalid", e.messageType)
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type Query struct {
|
||||
String string
|
||||
}
|
||||
|
||||
func (*Query) Frontend() {}
|
||||
|
||||
func (dst *Query) Decode(src []byte) error {
|
||||
i := bytes.IndexByte(src, 0)
|
||||
if i != len(src)-1 {
|
||||
return &invalidMessageFormatErr{messageType: "Query"}
|
||||
}
|
||||
|
||||
dst.String = string(src[:i])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Query) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'Q')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1))
|
||||
|
||||
dst = append(dst, src.String...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *Query) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
String string
|
||||
}{
|
||||
Type: "Query",
|
||||
String: src.String,
|
||||
})
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type ReadyForQuery struct {
|
||||
TxStatus byte
|
||||
}
|
||||
|
||||
func (*ReadyForQuery) Backend() {}
|
||||
|
||||
func (dst *ReadyForQuery) Decode(src []byte) error {
|
||||
if len(src) != 1 {
|
||||
return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)}
|
||||
}
|
||||
|
||||
dst.TxStatus = src[0]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *ReadyForQuery) Encode(dst []byte) []byte {
|
||||
return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus)
|
||||
}
|
||||
|
||||
func (src *ReadyForQuery) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
TxStatus string
|
||||
}{
|
||||
Type: "ReadyForQuery",
|
||||
TxStatus: string(src.TxStatus),
|
||||
})
|
||||
}
|
||||
@@ -1,100 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
const (
|
||||
TextFormat = 0
|
||||
BinaryFormat = 1
|
||||
)
|
||||
|
||||
type FieldDescription struct {
|
||||
Name string
|
||||
TableOID uint32
|
||||
TableAttributeNumber uint16
|
||||
DataTypeOID uint32
|
||||
DataTypeSize int16
|
||||
TypeModifier int32
|
||||
Format int16
|
||||
}
|
||||
|
||||
type RowDescription struct {
|
||||
Fields []FieldDescription
|
||||
}
|
||||
|
||||
func (*RowDescription) Backend() {}
|
||||
|
||||
func (dst *RowDescription) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "RowDescription"}
|
||||
}
|
||||
fieldCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
|
||||
dst.Fields = dst.Fields[0:0]
|
||||
|
||||
for i := 0; i < fieldCount; i++ {
|
||||
var fd FieldDescription
|
||||
bName, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fd.Name = string(bName[:len(bName)-1])
|
||||
|
||||
// Since buf.Next() doesn't return an error if we hit the end of the buffer
|
||||
// check Len ahead of time
|
||||
if buf.Len() < 18 {
|
||||
return &invalidMessageFormatErr{messageType: "RowDescription"}
|
||||
}
|
||||
|
||||
fd.TableOID = binary.BigEndian.Uint32(buf.Next(4))
|
||||
fd.TableAttributeNumber = binary.BigEndian.Uint16(buf.Next(2))
|
||||
fd.DataTypeOID = binary.BigEndian.Uint32(buf.Next(4))
|
||||
fd.DataTypeSize = int16(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
fd.TypeModifier = int32(binary.BigEndian.Uint32(buf.Next(4)))
|
||||
fd.Format = int16(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
|
||||
dst.Fields = append(dst.Fields, fd)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *RowDescription) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'T')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.Fields)))
|
||||
for _, fd := range src.Fields {
|
||||
dst = append(dst, fd.Name...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
dst = pgio.AppendUint32(dst, fd.TableOID)
|
||||
dst = pgio.AppendUint16(dst, fd.TableAttributeNumber)
|
||||
dst = pgio.AppendUint32(dst, fd.DataTypeOID)
|
||||
dst = pgio.AppendInt16(dst, fd.DataTypeSize)
|
||||
dst = pgio.AppendInt32(dst, fd.TypeModifier)
|
||||
dst = pgio.AppendInt16(dst, fd.Format)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *RowDescription) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Fields []FieldDescription
|
||||
}{
|
||||
Type: "RowDescription",
|
||||
Fields: src.Fields,
|
||||
})
|
||||
}
|
||||
@@ -1,97 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
ProtocolVersionNumber = 196608 // 3.0
|
||||
sslRequestNumber = 80877103
|
||||
)
|
||||
|
||||
type StartupMessage struct {
|
||||
ProtocolVersion uint32
|
||||
Parameters map[string]string
|
||||
}
|
||||
|
||||
func (*StartupMessage) Frontend() {}
|
||||
|
||||
func (dst *StartupMessage) Decode(src []byte) error {
|
||||
if len(src) < 4 {
|
||||
return errors.Errorf("startup message too short")
|
||||
}
|
||||
|
||||
dst.ProtocolVersion = binary.BigEndian.Uint32(src)
|
||||
rp := 4
|
||||
|
||||
if dst.ProtocolVersion == sslRequestNumber {
|
||||
return errors.Errorf("can't handle ssl connection request")
|
||||
}
|
||||
|
||||
if dst.ProtocolVersion != ProtocolVersionNumber {
|
||||
return errors.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion)
|
||||
}
|
||||
|
||||
dst.Parameters = make(map[string]string)
|
||||
for {
|
||||
idx := bytes.IndexByte(src[rp:], 0)
|
||||
if idx < 0 {
|
||||
return &invalidMessageFormatErr{messageType: "StartupMesage"}
|
||||
}
|
||||
key := string(src[rp : rp+idx])
|
||||
rp += idx + 1
|
||||
|
||||
idx = bytes.IndexByte(src[rp:], 0)
|
||||
if idx < 0 {
|
||||
return &invalidMessageFormatErr{messageType: "StartupMesage"}
|
||||
}
|
||||
value := string(src[rp : rp+idx])
|
||||
rp += idx + 1
|
||||
|
||||
dst.Parameters[key] = value
|
||||
|
||||
if len(src[rp:]) == 1 {
|
||||
if src[rp] != 0 {
|
||||
return errors.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp])
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *StartupMessage) Encode(dst []byte) []byte {
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = pgio.AppendUint32(dst, src.ProtocolVersion)
|
||||
for k, v := range src.Parameters {
|
||||
dst = append(dst, k...)
|
||||
dst = append(dst, 0)
|
||||
dst = append(dst, v...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *StartupMessage) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ProtocolVersion uint32
|
||||
Parameters map[string]string
|
||||
}{
|
||||
Type: "StartupMessage",
|
||||
ProtocolVersion: src.ProtocolVersion,
|
||||
Parameters: src.Parameters,
|
||||
})
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type Sync struct{}
|
||||
|
||||
func (*Sync) Frontend() {}
|
||||
|
||||
func (dst *Sync) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "Sync", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Sync) Encode(dst []byte) []byte {
|
||||
return append(dst, 'S', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
func (src *Sync) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "Sync",
|
||||
})
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type Terminate struct{}
|
||||
|
||||
func (*Terminate) Frontend() {}
|
||||
|
||||
func (dst *Terminate) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "Terminate", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Terminate) Encode(dst []byte) []byte {
|
||||
return append(dst, 'X', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
func (src *Terminate) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "Terminate",
|
||||
})
|
||||
}
|
||||
+1
-1
@@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
+1
-1
@@ -8,7 +8,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
+1
-1
@@ -8,7 +8,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
+1
-1
@@ -5,7 +5,7 @@ import (
|
||||
"encoding/binary"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"encoding/binary"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
+1
-1
@@ -3,7 +3,7 @@ package pgtype
|
||||
import (
|
||||
"database/sql/driver"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
+1
-1
@@ -6,7 +6,7 @@ import (
|
||||
"math"
|
||||
"strconv"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
+1
-1
@@ -6,7 +6,7 @@ import (
|
||||
"math"
|
||||
"strconv"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
+1
-1
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
// Hstore represents an hstore column that can be null or have null values
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
+1
-1
@@ -6,7 +6,7 @@ import (
|
||||
"math"
|
||||
"strconv"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
+1
-1
@@ -7,7 +7,7 @@ import (
|
||||
"math"
|
||||
"strconv"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
+1
-1
@@ -3,7 +3,7 @@ package pgtype
|
||||
import (
|
||||
"database/sql/driver"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
+1
-1
@@ -7,7 +7,7 @@ import (
|
||||
"math"
|
||||
"strconv"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
+1
-1
@@ -3,7 +3,7 @@ package pgtype
|
||||
import (
|
||||
"database/sql/driver"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
+1
-1
@@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
+1
-1
@@ -8,7 +8,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
+1
-1
@@ -8,7 +8,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user