Import pgproto3
Also copy in pgmock as an internal package.
This commit is contained in:
@@ -0,0 +1,135 @@
|
||||
// Package pgmock provides the ability to mock a PostgreSQL server.
|
||||
package pgmock
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
)
|
||||
|
||||
type Step interface {
|
||||
Step(*pgproto3.Backend) error
|
||||
}
|
||||
|
||||
type Script struct {
|
||||
Steps []Step
|
||||
}
|
||||
|
||||
func (s *Script) Run(backend *pgproto3.Backend) error {
|
||||
for _, step := range s.Steps {
|
||||
err := step.Step(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Script) Step(backend *pgproto3.Backend) error {
|
||||
return s.Run(backend)
|
||||
}
|
||||
|
||||
type expectMessageStep struct {
|
||||
want pgproto3.FrontendMessage
|
||||
any bool
|
||||
}
|
||||
|
||||
func (e *expectMessageStep) Step(backend *pgproto3.Backend) error {
|
||||
msg, err := backend.Receive()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if e.any && reflect.TypeOf(msg) == reflect.TypeOf(e.want) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(msg, e.want) {
|
||||
return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type expectStartupMessageStep struct {
|
||||
want *pgproto3.StartupMessage
|
||||
any bool
|
||||
}
|
||||
|
||||
func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error {
|
||||
msg, err := backend.ReceiveStartupMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if e.any {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(msg, e.want) {
|
||||
return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ExpectMessage(want pgproto3.FrontendMessage) Step {
|
||||
return expectMessage(want, false)
|
||||
}
|
||||
|
||||
func ExpectAnyMessage(want pgproto3.FrontendMessage) Step {
|
||||
return expectMessage(want, true)
|
||||
}
|
||||
|
||||
func expectMessage(want pgproto3.FrontendMessage, any bool) Step {
|
||||
if want, ok := want.(*pgproto3.StartupMessage); ok {
|
||||
return &expectStartupMessageStep{want: want, any: any}
|
||||
}
|
||||
|
||||
return &expectMessageStep{want: want, any: any}
|
||||
}
|
||||
|
||||
type sendMessageStep struct {
|
||||
msg pgproto3.BackendMessage
|
||||
}
|
||||
|
||||
func (e *sendMessageStep) Step(backend *pgproto3.Backend) error {
|
||||
return backend.Send(e.msg)
|
||||
}
|
||||
|
||||
func SendMessage(msg pgproto3.BackendMessage) Step {
|
||||
return &sendMessageStep{msg: msg}
|
||||
}
|
||||
|
||||
type waitForCloseMessageStep struct{}
|
||||
|
||||
func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error {
|
||||
for {
|
||||
msg, err := backend.Receive()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, ok := msg.(*pgproto3.Terminate); ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WaitForClose() Step {
|
||||
return &waitForCloseMessageStep{}
|
||||
}
|
||||
|
||||
func AcceptUnauthenticatedConnRequestSteps() []Step {
|
||||
return []Step{
|
||||
ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
|
||||
SendMessage(&pgproto3.AuthenticationOk{}),
|
||||
SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
|
||||
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package pgmock_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgmock"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestScript(t *testing.T) {
|
||||
script := &pgmock.Script{
|
||||
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
|
||||
}
|
||||
script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Query{String: "select 42"}))
|
||||
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.RowDescription{
|
||||
Fields: []pgproto3.FieldDescription{
|
||||
pgproto3.FieldDescription{
|
||||
Name: []byte("?column?"),
|
||||
TableOID: 0,
|
||||
TableAttributeNumber: 0,
|
||||
DataTypeOID: 23,
|
||||
DataTypeSize: 4,
|
||||
TypeModifier: -1,
|
||||
Format: 0,
|
||||
},
|
||||
},
|
||||
}))
|
||||
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.DataRow{
|
||||
Values: [][]byte{[]byte("42")},
|
||||
}))
|
||||
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}))
|
||||
script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}))
|
||||
script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Terminate{}))
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
serverErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(serverErrChan)
|
||||
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.SetDeadline(time.Now().Add(time.Second))
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn))
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
parts := strings.Split(ln.Addr().String(), ":")
|
||||
host := parts[0]
|
||||
port := parts[1]
|
||||
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
pgConn, err := pgconn.Connect(ctx, connStr)
|
||||
require.NoError(t, err)
|
||||
results, err := pgConn.Exec(ctx, "select 42").ReadAll()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Len(t, results, 1)
|
||||
assert.Nil(t, results[0].Err)
|
||||
assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
|
||||
assert.Len(t, results[0].Rows, 1)
|
||||
assert.Equal(t, "42", string(results[0].Rows[0][0]))
|
||||
|
||||
pgConn.Close(ctx)
|
||||
|
||||
assert.NoError(t, <-serverErrChan)
|
||||
}
|
||||
Reference in New Issue
Block a user