From 9275da562f356273288268731252a8f1be60d5f4 Mon Sep 17 00:00:00 2001 From: Martin Ashby Date: Wed, 3 Nov 2021 19:43:46 +0000 Subject: [PATCH] Added FunctionCall support Added support for FunctionCall message as per https://www.postgresql.org/docs/11/protocol-message-formats.html Adds unit test for Encode / Decode cycle and invalid message format errors. Fixes https://github.com/jackc/pgproto3/issues/23 --- backend.go | 12 +++-- function_call.go | 104 ++++++++++++++++++++++++++++++++++++++++++ function_call_test.go | 44 ++++++++++++++++++ go.mod | 1 + go.sum | 3 ++ 5 files changed, 160 insertions(+), 4 deletions(-) create mode 100644 function_call.go create mode 100644 function_call_test.go diff --git a/backend.go b/backend.go index e9ba38fc..9c42ad02 100644 --- a/backend.go +++ b/backend.go @@ -21,6 +21,7 @@ type Backend struct { describe Describe execute Execute flush Flush + functionCall FunctionCall gssEncRequest GSSEncRequest parse Parse query Query @@ -29,10 +30,11 @@ type Backend struct { sync Sync terminate Terminate - bodyLen int - msgType byte - partialMsg bool - authType uint32 + bodyLen int + msgType byte + partialMsg bool + authType uint32 + } const ( @@ -125,6 +127,8 @@ func (b *Backend) Receive() (FrontendMessage, error) { msg = &b.describe case 'E': msg = &b.execute + case 'F': + msg = &b.functionCall case 'f': msg = &b.copyFail case 'd': diff --git a/function_call.go b/function_call.go new file mode 100644 index 00000000..74d3c3c7 --- /dev/null +++ b/function_call.go @@ -0,0 +1,104 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "github.com/jackc/pgio" +) + +type FunctionCall struct{ + Function uint32 + ArgFormatCodes []uint16 + Arguments [][]byte + ResultFormatCode uint16 +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*FunctionCall) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *FunctionCall) Decode(src []byte) error { + *dst = FunctionCall{} + rp := 0 + // Specifies the object ID of the function to call. + dst.Function = binary.BigEndian.Uint32(src[rp:]) + rp += 4 + // The number of argument format codes that follow (denoted C below). + // This can be zero to indicate that there are no arguments or that the arguments all use the default format (text); + // or one, in which case the specified format code is applied to all arguments; + // or it can equal the actual number of arguments. + nArgumentCodes := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + argumentCodes := make([]uint16, nArgumentCodes) + for i := 0; i < nArgumentCodes; i++ { + // The argument format codes. Each must presently be zero (text) or one (binary). + ac := binary.BigEndian.Uint16(src[rp:]) + if ac != 0 && ac != 1 { + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } + argumentCodes[i] = ac + rp += 2 + } + dst.ArgFormatCodes = argumentCodes + + // Specifies the number of arguments being supplied to the function. + nArguments := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + arguments := make([][]byte, nArguments) + for i := 0; i < nArguments; i++ { + // The length of the argument value, in bytes (this count does not include itself). Can be zero. + // As a special case, -1 indicates a NULL argument value. No value bytes follow in the NULL case. + argumentLength := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + if argumentLength == -1 { + arguments[i] = nil + } else { + // The value of the argument, in the format indicated by the associated format code. n is the above length. + argumentValue := src[rp:rp+argumentLength] + rp += argumentLength + arguments[i] = argumentValue + } + } + dst.Arguments = arguments + // The format code for the function result. Must presently be zero (text) or one (binary). + resultFormatCode := binary.BigEndian.Uint16(src[rp:]) + if resultFormatCode != 0 && resultFormatCode != 1 { + return &invalidMessageFormatErr{messageType: "FunctionCall"} + } + dst.ResultFormatCode = resultFormatCode + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *FunctionCall) Encode(dst []byte) []byte { + dst = append(dst, 'F') + sp := len(dst) + dst = pgio.AppendUint32(dst, 0) // Unknown length, set it at the end + dst = pgio.AppendUint32(dst, src.Function) + dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes))) + for _, argFormatCode := range src.ArgFormatCodes { + dst = pgio.AppendUint16(dst, argFormatCode) + } + dst = pgio.AppendUint16(dst, uint16(len(src.Arguments))) + for _, argument := range src.Arguments { + if argument == nil { + dst = pgio.AppendInt32(dst, -1) + } else { + dst = pgio.AppendInt32(dst, int32(len(argument))) + dst = append(dst, argument...) + } + } + dst = pgio.AppendUint16(dst, src.ResultFormatCode) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src FunctionCall) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "FunctionCall", + }) +} diff --git a/function_call_test.go b/function_call_test.go new file mode 100644 index 00000000..f1586560 --- /dev/null +++ b/function_call_test.go @@ -0,0 +1,44 @@ +package pgproto3 + +import ( + "github.com/go-test/deep" + "testing" +) + +func TestFunctionCall_EncodeDecode(t *testing.T) { + type fields struct { + Function uint32 + ArgFormatCodes []uint16 + Arguments [][]byte + ResultFormatCode uint16 + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + {"foo", fields{uint32(123), []uint16{0, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(0)}, false}, + {"invalid format code", fields{uint32(123), []uint16{2, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(0)}, true}, + {"invalid result format code", fields{uint32(123), []uint16{1, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(2)}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + src := &FunctionCall{ + Function: tt.fields.Function, + ArgFormatCodes: tt.fields.ArgFormatCodes, + Arguments: tt.fields.Arguments, + ResultFormatCode: tt.fields.ResultFormatCode, + } + encoded := src.Encode([]byte{}) + decoded := &FunctionCall{} + err := decoded.Decode(encoded[5:]) + if (err != nil) != tt.wantErr { + t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr) + return + } + if diff := deep.Equal(src, decoded); diff != nil { + t.Error(diff) + } + }) + } +} \ No newline at end of file diff --git a/go.mod b/go.mod index 36041a94..030953b5 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/jackc/pgproto3/v2 go 1.12 require ( + github.com/go-test/deep v1.0.8 github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 github.com/stretchr/testify v1.4.0 diff --git a/go.sum b/go.sum index dd9cd044..765190c1 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= +github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= @@ -9,6 +11,7 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=