6972a57421
It was a mistake to use it in other contexts. This made interop difficult between packages that depended on pgtype, such as pgx, and packages that did not, like pgconn and pgproto3. In particular this was awkward for prepared statements. This is preparation for removing pgx.PreparedStatement in favor of pgconn.PreparedStatement.
602 lines
14 KiB
Go
602 lines
14 KiB
Go
package pgmock
|
|
|
|
import (
|
|
"io"
|
|
"net"
|
|
"reflect"
|
|
|
|
errors "golang.org/x/xerrors"
|
|
|
|
"github.com/jackc/pgproto3/v2"
|
|
"github.com/jackc/pgtype"
|
|
)
|
|
|
|
type Server struct {
|
|
ln net.Listener
|
|
controller Controller
|
|
}
|
|
|
|
func NewServer(controller Controller) (*Server, error) {
|
|
ln, err := net.Listen("tcp", "127.0.0.1:")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
server := &Server{
|
|
ln: ln,
|
|
controller: controller,
|
|
}
|
|
|
|
return server, nil
|
|
}
|
|
|
|
func (s *Server) Addr() net.Addr {
|
|
return s.ln.Addr()
|
|
}
|
|
|
|
func (s *Server) ServeOne() error {
|
|
conn, err := s.ln.Accept()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer conn.Close()
|
|
|
|
s.Close()
|
|
|
|
backend, err := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)
|
|
if err != nil {
|
|
conn.Close()
|
|
return err
|
|
}
|
|
|
|
return s.controller.Serve(backend)
|
|
}
|
|
|
|
func (s *Server) Close() error {
|
|
err := s.ln.Close()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type Controller interface {
|
|
Serve(backend *pgproto3.Backend) error
|
|
}
|
|
|
|
type Step interface {
|
|
Step(*pgproto3.Backend) error
|
|
}
|
|
|
|
type Script struct {
|
|
Steps []Step
|
|
}
|
|
|
|
func (s *Script) Run(backend *pgproto3.Backend) error {
|
|
for _, step := range s.Steps {
|
|
err := step.Step(backend)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Script) Serve(backend *pgproto3.Backend) error {
|
|
for _, step := range s.Steps {
|
|
err := step.Step(backend)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Script) Step(backend *pgproto3.Backend) error {
|
|
return s.Serve(backend)
|
|
}
|
|
|
|
type expectMessageStep struct {
|
|
want pgproto3.FrontendMessage
|
|
any bool
|
|
}
|
|
|
|
func (e *expectMessageStep) Step(backend *pgproto3.Backend) error {
|
|
msg, err := backend.Receive()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if e.any && reflect.TypeOf(msg) == reflect.TypeOf(e.want) {
|
|
return nil
|
|
}
|
|
|
|
if !reflect.DeepEqual(msg, e.want) {
|
|
return errors.Errorf("msg => %#v, e.want => %#v", msg, e.want)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type expectStartupMessageStep struct {
|
|
want *pgproto3.StartupMessage
|
|
any bool
|
|
}
|
|
|
|
func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error {
|
|
msg, err := backend.ReceiveStartupMessage()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if e.any {
|
|
return nil
|
|
}
|
|
|
|
if !reflect.DeepEqual(msg, e.want) {
|
|
return errors.Errorf("msg => %#v, e.want => %#v", msg, e.want)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func ExpectMessage(want pgproto3.FrontendMessage) Step {
|
|
return expectMessage(want, false)
|
|
}
|
|
|
|
func ExpectAnyMessage(want pgproto3.FrontendMessage) Step {
|
|
return expectMessage(want, true)
|
|
}
|
|
|
|
func expectMessage(want pgproto3.FrontendMessage, any bool) Step {
|
|
if want, ok := want.(*pgproto3.StartupMessage); ok {
|
|
return &expectStartupMessageStep{want: want, any: any}
|
|
}
|
|
|
|
return &expectMessageStep{want: want, any: any}
|
|
}
|
|
|
|
type sendMessageStep struct {
|
|
msg pgproto3.BackendMessage
|
|
}
|
|
|
|
func (e *sendMessageStep) Step(backend *pgproto3.Backend) error {
|
|
return backend.Send(e.msg)
|
|
}
|
|
|
|
func SendMessage(msg pgproto3.BackendMessage) Step {
|
|
return &sendMessageStep{msg: msg}
|
|
}
|
|
|
|
// waitForCloseMessageStep is a Step that consumes messages until the client
// terminates or the connection reaches EOF. It carries no state.
type waitForCloseMessageStep struct{}
|
|
|
|
func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error {
|
|
for {
|
|
msg, err := backend.Receive()
|
|
if err == io.EOF {
|
|
return nil
|
|
} else if err != nil {
|
|
return err
|
|
}
|
|
|
|
if _, ok := msg.(*pgproto3.Terminate); ok {
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
func WaitForClose() Step {
|
|
return &waitForCloseMessageStep{}
|
|
}
|
|
|
|
func AcceptUnauthenticatedConnRequestSteps() []Step {
|
|
return []Step{
|
|
ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
|
|
SendMessage(&pgproto3.Authentication{Type: pgproto3.AuthTypeOk}),
|
|
SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
|
|
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
|
}
|
|
}
|
|
|
|
func PgxInitSteps() []Step {
|
|
steps := []Step{
|
|
ExpectMessage(&pgproto3.Parse{
|
|
Query: `select t.oid,
|
|
case when nsp.nspname in ('pg_catalog', 'public') then t.typname
|
|
else nsp.nspname||'.'||t.typname
|
|
end
|
|
from pg_type t
|
|
left join pg_type base_type on t.typelem=base_type.oid
|
|
left join pg_namespace nsp on t.typnamespace=nsp.oid
|
|
where (
|
|
t.typtype in('b', 'p', 'r', 'e')
|
|
and (base_type.oid is null or base_type.typtype in('b', 'p', 'r'))
|
|
)`,
|
|
}),
|
|
ExpectMessage(&pgproto3.Describe{
|
|
ObjectType: 'S',
|
|
}),
|
|
ExpectMessage(&pgproto3.Sync{}),
|
|
SendMessage(&pgproto3.ParseComplete{}),
|
|
SendMessage(&pgproto3.ParameterDescription{}),
|
|
SendMessage(&pgproto3.RowDescription{
|
|
Fields: []pgproto3.FieldDescription{
|
|
{Name: []byte("oid"),
|
|
TableOID: 1247,
|
|
TableAttributeNumber: 65534,
|
|
DataTypeOID: 26,
|
|
DataTypeSize: 4,
|
|
TypeModifier: -1,
|
|
Format: 0,
|
|
},
|
|
{Name: []byte("typname"),
|
|
TableOID: 1247,
|
|
TableAttributeNumber: 1,
|
|
DataTypeOID: 19,
|
|
DataTypeSize: 64,
|
|
TypeModifier: -1,
|
|
Format: 0,
|
|
},
|
|
},
|
|
}),
|
|
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
|
ExpectMessage(&pgproto3.Bind{
|
|
ResultFormatCodes: []int16{1, 1},
|
|
}),
|
|
ExpectMessage(&pgproto3.Execute{}),
|
|
ExpectMessage(&pgproto3.Sync{}),
|
|
SendMessage(&pgproto3.BindComplete{}),
|
|
}
|
|
|
|
rowVals := []struct {
|
|
oid uint32
|
|
name string
|
|
}{
|
|
{16, "bool"},
|
|
{17, "bytea"},
|
|
{18, "char"},
|
|
{19, "name"},
|
|
{20, "int8"},
|
|
{21, "int2"},
|
|
{22, "int2vector"},
|
|
{23, "int4"},
|
|
{24, "regproc"},
|
|
{25, "text"},
|
|
{26, "oid"},
|
|
{27, "tid"},
|
|
{28, "xid"},
|
|
{29, "cid"},
|
|
{30, "oidvector"},
|
|
{114, "json"},
|
|
{142, "xml"},
|
|
{143, "_xml"},
|
|
{199, "_json"},
|
|
{194, "pg_node_tree"},
|
|
{32, "pg_ddl_command"},
|
|
{210, "smgr"},
|
|
{600, "point"},
|
|
{601, "lseg"},
|
|
{602, "path"},
|
|
{603, "box"},
|
|
{604, "polygon"},
|
|
{628, "line"},
|
|
{629, "_line"},
|
|
{700, "float4"},
|
|
{701, "float8"},
|
|
{702, "abstime"},
|
|
{703, "reltime"},
|
|
{704, "tinterval"},
|
|
{705, "unknown"},
|
|
{718, "circle"},
|
|
{719, "_circle"},
|
|
{790, "money"},
|
|
{791, "_money"},
|
|
{829, "macaddr"},
|
|
{869, "inet"},
|
|
{650, "cidr"},
|
|
{1000, "_bool"},
|
|
{1001, "_bytea"},
|
|
{1002, "_char"},
|
|
{1003, "_name"},
|
|
{1005, "_int2"},
|
|
{1006, "_int2vector"},
|
|
{1007, "_int4"},
|
|
{1008, "_regproc"},
|
|
{1009, "_text"},
|
|
{1028, "_oid"},
|
|
{1010, "_tid"},
|
|
{1011, "_xid"},
|
|
{1012, "_cid"},
|
|
{1013, "_oidvector"},
|
|
{1014, "_bpchar"},
|
|
{1015, "_varchar"},
|
|
{1016, "_int8"},
|
|
{1017, "_point"},
|
|
{1018, "_lseg"},
|
|
{1019, "_path"},
|
|
{1020, "_box"},
|
|
{1021, "_float4"},
|
|
{1022, "_float8"},
|
|
{1023, "_abstime"},
|
|
{1024, "_reltime"},
|
|
{1025, "_tinterval"},
|
|
{1027, "_polygon"},
|
|
{1033, "aclitem"},
|
|
{1034, "_aclitem"},
|
|
{1040, "_macaddr"},
|
|
{1041, "_inet"},
|
|
{651, "_cidr"},
|
|
{1263, "_cstring"},
|
|
{1042, "bpchar"},
|
|
{1043, "varchar"},
|
|
{1082, "date"},
|
|
{1083, "time"},
|
|
{1114, "timestamp"},
|
|
{1115, "_timestamp"},
|
|
{1182, "_date"},
|
|
{1183, "_time"},
|
|
{1184, "timestamptz"},
|
|
{1185, "_timestamptz"},
|
|
{1186, "interval"},
|
|
{1187, "_interval"},
|
|
{1231, "_numeric"},
|
|
{1266, "timetz"},
|
|
{1270, "_timetz"},
|
|
{1560, "bit"},
|
|
{1561, "_bit"},
|
|
{1562, "varbit"},
|
|
{1563, "_varbit"},
|
|
{1700, "numeric"},
|
|
{1790, "refcursor"},
|
|
{2201, "_refcursor"},
|
|
{2202, "regprocedure"},
|
|
{2203, "regoper"},
|
|
{2204, "regoperator"},
|
|
{2205, "regclass"},
|
|
{2206, "regtype"},
|
|
{4096, "regrole"},
|
|
{4089, "regnamespace"},
|
|
{2207, "_regprocedure"},
|
|
{2208, "_regoper"},
|
|
{2209, "_regoperator"},
|
|
{2210, "_regclass"},
|
|
{2211, "_regtype"},
|
|
{4097, "_regrole"},
|
|
{4090, "_regnamespace"},
|
|
{2950, "uuid"},
|
|
{2951, "_uuid"},
|
|
{3220, "pg_lsn"},
|
|
{3221, "_pg_lsn"},
|
|
{3614, "tsvector"},
|
|
{3642, "gtsvector"},
|
|
{3615, "tsquery"},
|
|
{3734, "regconfig"},
|
|
{3769, "regdictionary"},
|
|
{3643, "_tsvector"},
|
|
{3644, "_gtsvector"},
|
|
{3645, "_tsquery"},
|
|
{3735, "_regconfig"},
|
|
{3770, "_regdictionary"},
|
|
{3802, "jsonb"},
|
|
{3807, "_jsonb"},
|
|
{2970, "txid_snapshot"},
|
|
{2949, "_txid_snapshot"},
|
|
{3904, "int4range"},
|
|
{3905, "_int4range"},
|
|
{3906, "numrange"},
|
|
{3907, "_numrange"},
|
|
{3908, "tsrange"},
|
|
{3909, "_tsrange"},
|
|
{3910, "tstzrange"},
|
|
{3911, "_tstzrange"},
|
|
{3912, "daterange"},
|
|
{3913, "_daterange"},
|
|
{3926, "int8range"},
|
|
{3927, "_int8range"},
|
|
{2249, "record"},
|
|
{2287, "_record"},
|
|
{2275, "cstring"},
|
|
{2276, "any"},
|
|
{2277, "anyarray"},
|
|
{2278, "void"},
|
|
{2279, "trigger"},
|
|
{3838, "event_trigger"},
|
|
{2280, "language_handler"},
|
|
{2281, "internal"},
|
|
{2282, "opaque"},
|
|
{2283, "anyelement"},
|
|
{2776, "anynonarray"},
|
|
{3500, "anyenum"},
|
|
{3115, "fdw_handler"},
|
|
{325, "index_am_handler"},
|
|
{3310, "tsm_handler"},
|
|
{3831, "anyrange"},
|
|
{51367, "gbtreekey4"},
|
|
{51370, "_gbtreekey4"},
|
|
{51371, "gbtreekey8"},
|
|
{51374, "_gbtreekey8"},
|
|
{51375, "gbtreekey16"},
|
|
{51378, "_gbtreekey16"},
|
|
{51379, "gbtreekey32"},
|
|
{51382, "_gbtreekey32"},
|
|
{51383, "gbtreekey_var"},
|
|
{51386, "_gbtreekey_var"},
|
|
{51921, "hstore"},
|
|
{51926, "_hstore"},
|
|
{52005, "ghstore"},
|
|
{52008, "_ghstore"},
|
|
}
|
|
|
|
for _, rv := range rowVals {
|
|
step := SendMessage(mustBuildDataRow([]interface{}{rv.oid, rv.name}, []int16{pgproto3.BinaryFormat}))
|
|
steps = append(steps, step)
|
|
}
|
|
|
|
steps = append(steps, SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 163")}))
|
|
steps = append(steps, SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}))
|
|
|
|
steps = append(steps, []Step{
|
|
ExpectMessage(&pgproto3.Parse{
|
|
Query: "select t.oid, t.typname\nfrom pg_type t\n join pg_type base_type on t.typelem=base_type.oid\nwhere t.typtype = 'b'\n and base_type.typtype = 'e'",
|
|
}),
|
|
ExpectMessage(&pgproto3.Describe{
|
|
ObjectType: 'S',
|
|
}),
|
|
ExpectMessage(&pgproto3.Sync{}),
|
|
SendMessage(&pgproto3.ParseComplete{}),
|
|
SendMessage(&pgproto3.ParameterDescription{}),
|
|
SendMessage(&pgproto3.RowDescription{
|
|
Fields: []pgproto3.FieldDescription{
|
|
{Name: []byte("oid"),
|
|
TableOID: 1247,
|
|
TableAttributeNumber: 65534,
|
|
DataTypeOID: 26,
|
|
DataTypeSize: 4,
|
|
TypeModifier: -1,
|
|
Format: 0,
|
|
},
|
|
{Name: []byte("typname"),
|
|
TableOID: 1247,
|
|
TableAttributeNumber: 1,
|
|
DataTypeOID: 19,
|
|
DataTypeSize: 64,
|
|
TypeModifier: -1,
|
|
Format: 0,
|
|
},
|
|
},
|
|
}),
|
|
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
|
ExpectMessage(&pgproto3.Bind{
|
|
ResultFormatCodes: []int16{1, 1},
|
|
}),
|
|
ExpectMessage(&pgproto3.Execute{}),
|
|
ExpectMessage(&pgproto3.Sync{}),
|
|
SendMessage(&pgproto3.BindComplete{}),
|
|
SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 0")}),
|
|
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
|
}...)
|
|
|
|
steps = append(steps, []Step{
|
|
ExpectMessage(&pgproto3.Parse{
|
|
Query: "select t.oid, t.typname, t.typbasetype\nfrom pg_type t\n join pg_type base_type on t.typbasetype=base_type.oid\nwhere t.typtype = 'd'\n and base_type.typtype = 'b'",
|
|
}),
|
|
ExpectMessage(&pgproto3.Describe{
|
|
ObjectType: 'S',
|
|
}),
|
|
ExpectMessage(&pgproto3.Sync{}),
|
|
SendMessage(&pgproto3.ParseComplete{}),
|
|
SendMessage(&pgproto3.ParameterDescription{}),
|
|
SendMessage(&pgproto3.RowDescription{
|
|
Fields: []pgproto3.FieldDescription{
|
|
{Name: []byte("oid"),
|
|
TableOID: 1247,
|
|
TableAttributeNumber: 65534,
|
|
DataTypeOID: 26,
|
|
DataTypeSize: 4,
|
|
TypeModifier: -1,
|
|
Format: 0,
|
|
},
|
|
{Name: []byte("typname"),
|
|
TableOID: 1247,
|
|
TableAttributeNumber: 1,
|
|
DataTypeOID: 19,
|
|
DataTypeSize: 64,
|
|
TypeModifier: -1,
|
|
Format: 0,
|
|
},
|
|
{Name: []byte("typbasetype"),
|
|
TableOID: 1247,
|
|
TableAttributeNumber: 65534,
|
|
DataTypeOID: 26,
|
|
DataTypeSize: 4,
|
|
TypeModifier: -1,
|
|
Format: 0,
|
|
},
|
|
},
|
|
}),
|
|
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
|
ExpectMessage(&pgproto3.Bind{
|
|
ResultFormatCodes: []int16{1, 1, 1},
|
|
}),
|
|
ExpectMessage(&pgproto3.Execute{}),
|
|
ExpectMessage(&pgproto3.Sync{}),
|
|
SendMessage(&pgproto3.BindComplete{}),
|
|
SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 0")}),
|
|
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
|
}...)
|
|
|
|
return steps
|
|
}
|
|
|
|
// dataRowValue pairs a value with a format code.
//
// NOTE(review): nothing in the visible source references this type — confirm
// whether it is still needed.
type dataRowValue struct {
	// Value is the value to encode.
	Value interface{}

	// FormatCode is the format code associated with Value.
	FormatCode int16
}
|
|
|
|
func mustBuildDataRow(values []interface{}, formatCodes []int16) *pgproto3.DataRow {
|
|
dr, err := buildDataRow(values, formatCodes)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
return dr
|
|
}
|
|
|
|
func buildDataRow(values []interface{}, formatCodes []int16) (*pgproto3.DataRow, error) {
|
|
dr := &pgproto3.DataRow{
|
|
Values: make([][]byte, len(values)),
|
|
}
|
|
|
|
if len(formatCodes) == 1 {
|
|
for i := 1; i < len(values); i++ {
|
|
formatCodes = append(formatCodes, formatCodes[0])
|
|
}
|
|
}
|
|
|
|
for i := range values {
|
|
switch v := values[i].(type) {
|
|
case string:
|
|
values[i] = &pgtype.Text{String: v, Status: pgtype.Present}
|
|
case int16:
|
|
values[i] = &pgtype.Int2{Int: v, Status: pgtype.Present}
|
|
case int32:
|
|
values[i] = &pgtype.Int4{Int: v, Status: pgtype.Present}
|
|
case int64:
|
|
values[i] = &pgtype.Int8{Int: v, Status: pgtype.Present}
|
|
}
|
|
}
|
|
|
|
for i := range values {
|
|
switch formatCodes[i] {
|
|
case pgproto3.TextFormat:
|
|
if e, ok := values[i].(pgtype.TextEncoder); ok {
|
|
buf, err := e.EncodeText(nil, nil)
|
|
if err != nil {
|
|
return nil, errors.Errorf("failed to encode values[%d]", i)
|
|
}
|
|
dr.Values[i] = buf
|
|
} else {
|
|
return nil, errors.Errorf("values[%d] does not implement TextExcoder", i)
|
|
}
|
|
|
|
case pgproto3.BinaryFormat:
|
|
if e, ok := values[i].(pgtype.BinaryEncoder); ok {
|
|
buf, err := e.EncodeBinary(nil, nil)
|
|
if err != nil {
|
|
return nil, errors.Errorf("failed to encode values[%d]", i)
|
|
}
|
|
dr.Values[i] = buf
|
|
} else {
|
|
return nil, errors.Errorf("values[%d] does not implement BinaryEncoder", i)
|
|
}
|
|
default:
|
|
return nil, errors.New("unknown FormatCode")
|
|
}
|
|
}
|
|
|
|
return dr, nil
|
|
}
|