2
0

json: fix implementation of json Unmarshalers.

* AuthenticationMD5Password was wrong and is not needed
* Bind was wrong
* ErrorResponse is not needed
* Minor improvements for reliability
This commit is contained in:
Henrique Vicente
2021-05-17 02:11:29 +02:00
committed by Jack Christensen
parent ba924e5715
commit 9c2c389e06
7 changed files with 530 additions and 106 deletions
-22
View File
@@ -2,7 +2,6 @@ package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
@@ -42,24 +41,3 @@ func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
dst = append(dst, src.Salt[:]...)
return dst
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Salt string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.Salt) != 4 {
return errors.New("invalid salt size")
}
copy(dst.Salt[:], []byte(msg.Salt)[:4])
return nil
}
+6 -8
View File
@@ -201,15 +201,13 @@ func (dst *Bind) UnmarshalJSON(data []byte) error {
if err != nil {
return err
}
bind := &Bind{
DestinationPortal: msg.DestinationPortal,
PreparedStatement: msg.PreparedStatement,
ParameterFormatCodes: msg.ParameterFormatCodes,
Parameters: make([][]byte, len(msg.Parameters)),
ResultFormatCodes: msg.ResultFormatCodes,
}
dst.DestinationPortal = msg.DestinationPortal
dst.PreparedStatement = msg.PreparedStatement
dst.ParameterFormatCodes = msg.ParameterFormatCodes
dst.Parameters = make([][]byte, len(msg.Parameters))
dst.ResultFormatCodes = msg.ResultFormatCodes
for n, parameter := range msg.Parameters {
bind.Parameters[n], err = getValueFromJSON(parameter)
dst.Parameters[n], err = getValueFromJSON(parameter)
if err != nil {
return fmt.Errorf("cannot get param %d: %w", n, err)
}
-67
View File
@@ -3,8 +3,6 @@ package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"strconv"
)
@@ -227,68 +225,3 @@ func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
return buf.Bytes()
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *ErrorResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Severity string
SeverityUnlocalized string // only in 9.6 and greater
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
UnknownFields map[string]string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Severity = msg.Severity
dst.SeverityUnlocalized = msg.SeverityUnlocalized
dst.Code = msg.Code
dst.Message = msg.Message
dst.Detail = msg.Detail
dst.Hint = msg.Hint
dst.Position = msg.Position
dst.InternalPosition = msg.InternalPosition
dst.InternalQuery = msg.InternalQuery
dst.Where = msg.Where
dst.SchemaName = msg.SchemaName
dst.TableName = msg.TableName
dst.ColumnName = msg.ColumnName
dst.DataTypeName = msg.DataTypeName
dst.ConstraintName = msg.ConstraintName
dst.File = msg.File
dst.Line = msg.Line
dst.Routine = msg.Routine
if msg.UnknownFields != nil {
dst.UnknownFields = map[byte]string{}
}
for k, v := range msg.UnknownFields {
if len(k) != 1 {
return fmt.Errorf("invalid UnknownFields field %q value", k)
}
dst.UnknownFields[k[0]] = v
}
return nil
}
+508
View File
@@ -0,0 +1,508 @@
package pgproto3
import (
"encoding/hex"
"encoding/json"
"reflect"
"testing"
)
func TestJSONUnmarshalAuthenticationMD5Password(t *testing.T) {
data := []byte(`{"Type":"AuthenticationMD5Password", "Salt":[97,98,99,100]}`)
want := AuthenticationMD5Password{
Salt: [4]byte{'a', 'b', 'c', 'd'},
}
var got AuthenticationMD5Password
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled AuthenticationMD5Password struct doesn't match expected value")
}
}
func TestJSONUnmarshalAuthenticationSASL(t *testing.T) {
data := []byte(`{"Type":"AuthenticationSASL", "AuthMechanisms":[]}`)
want := AuthenticationSASL{
AuthMechanisms: []string{},
}
var got AuthenticationSASL
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled AuthenticationSASL struct doesn't match expected value")
}
}
func TestJSONUnmarshalAuthenticationSASLContinue(t *testing.T) {
data := []byte(`{"Type":"AuthenticationSASLContinue"}`)
want := AuthenticationSASLContinue{
Data: []byte{},
}
var got AuthenticationSASLContinue
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled AuthenticationSASLContinue struct doesn't match expected value")
}
}
func TestJSONUnmarshalAuthenticationSASLFinal(t *testing.T) {
data := []byte(`{"Type":"AuthenticationSASLFinal"}`)
want := AuthenticationSASLFinal{
Data: []byte{},
}
var got AuthenticationSASLFinal
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled AuthenticationSASLFinal struct doesn't match expected value")
}
}
func TestJSONUnmarshalBackendKeyData(t *testing.T) {
data := []byte(`{"Type":"BackendKeyData","ProcessID":8864,"SecretKey":3641487067}`)
want := BackendKeyData{
ProcessID: 8864,
SecretKey: 3641487067,
}
var got BackendKeyData
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled BackendKeyData struct doesn't match expected value")
}
}
func TestJSONUnmarshalCommandComplete(t *testing.T) {
data := []byte(`{"Type":"CommandComplete","CommandTag":"SELECT 1"}`)
want := CommandComplete{
CommandTag: []byte("SELECT 1"),
}
var got CommandComplete
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled CommandComplete struct doesn't match expected value")
}
}
func TestJSONUnmarshalCopyBothResponse(t *testing.T) {
data := []byte(`{"Type":"CopyBothResponse", "OverallFormat": "W"}`)
want := CopyBothResponse{
OverallFormat: 'W',
}
var got CopyBothResponse
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled CopyBothResponse struct doesn't match expected value")
}
}
func TestJSONUnmarshalCopyData(t *testing.T) {
data := []byte(`{"Type":"CopyData"}`)
want := CopyData{
Data: []byte{},
}
var got CopyData
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled CopyData struct doesn't match expected value")
}
}
func TestJSONUnmarshalCopyInResponse(t *testing.T) {
data := []byte(`{"Type":"CopyBothResponse", "OverallFormat": "W"}`)
want := CopyBothResponse{
OverallFormat: 'W',
}
var got CopyBothResponse
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled CopyBothResponse struct doesn't match expected value")
}
}
func TestJSONUnmarshalCopyOutResponse(t *testing.T) {
data := []byte(`{"Type":"CopyOutResponse", "OverallFormat": "W"}`)
want := CopyOutResponse{
OverallFormat: 'W',
}
var got CopyOutResponse
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled CopyOutResponse struct doesn't match expected value")
}
}
func TestJSONUnmarshalDataRow(t *testing.T) {
data := []byte(`{"Type":"DataRow","Values":[{"text":"abc"},{"text":"this is a test"},{"binary":"000263d3114d2e34"}]}`)
want := DataRow{
Values: [][]byte{
[]byte("abc"),
[]byte("this is a test"),
{0, 2, 99, 211, 17, 77, 46, 52},
},
}
var got DataRow
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled DataRow struct doesn't match expected value")
}
}
func TestJSONUnmarshalErrorResponse(t *testing.T) {
data := []byte(`{"Type":"ErrorResponse", "UnknownFields": {"97": "foo"}}`)
want := ErrorResponse{
UnknownFields: map[byte]string{
'a': "foo",
},
}
var got ErrorResponse
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled ErrorResponse struct doesn't match expected value")
}
}
func TestJSONUnmarshalFunctionCallResponse(t *testing.T) {
data := []byte(`{"Type":"FunctionCallResponse"}`)
want := FunctionCallResponse{}
var got FunctionCallResponse
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled FunctionCallResponse struct doesn't match expected value")
}
}
func TestJSONUnmarshalNoticeResponse(t *testing.T) {
data := []byte(`{"Type":"NoticeResponse", "UnknownFields": {"97": "foo"}}`)
want := NoticeResponse{
UnknownFields: map[byte]string{
'a': "foo",
},
}
var got NoticeResponse
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled NoticeResponse struct doesn't match expected value")
}
}
func TestJSONUnmarshalNotificationResponse(t *testing.T) {
data := []byte(`{"Type":"NotificationResponse"}`)
want := NotificationResponse{}
var got NotificationResponse
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled NotificationResponse struct doesn't match expected value")
}
}
func TestJSONUnmarshalParameterDescription(t *testing.T) {
data := []byte(`{"Type":"ParameterDescription", "ParameterOIDs": [25]}`)
want := ParameterDescription{
ParameterOIDs: []uint32{25},
}
var got ParameterDescription
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled ParameterDescription struct doesn't match expected value")
}
}
func TestJSONUnmarshalParameterStatus(t *testing.T) {
data := []byte(`{"Type":"ParameterStatus","Name":"TimeZone","Value":"Europe/Amsterdam"}`)
want := ParameterStatus{
Name: "TimeZone",
Value: "Europe/Amsterdam",
}
var got ParameterStatus
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled ParameterDescription struct doesn't match expected value")
}
}
func TestJSONUnmarshalReadyForQuery(t *testing.T) {
data := []byte(`{"Type":"ReadyForQuery","TxStatus":"I"}`)
want := ReadyForQuery{
TxStatus: 'I',
}
var got ReadyForQuery
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled ParameterDescription struct doesn't match expected value")
}
}
func TestJSONUnmarshalRowDescription(t *testing.T) {
data := []byte(`{"Type":"RowDescription","Fields":[{"Name":"generate_series","TableOID":0,"TableAttributeNumber":0,"DataTypeOID":23,"DataTypeSize":4,"TypeModifier":-1,"Format":0}]}`)
want := RowDescription{
Fields: []FieldDescription{
{
Name: []byte("generate_series"),
DataTypeOID: 23,
DataTypeSize: 4,
TypeModifier: -1,
},
},
}
var got RowDescription
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled RowDescription struct doesn't match expected value")
}
}
func TestJSONUnmarshalBind(t *testing.T) {
var testCases = []struct {
desc string
data []byte
}{
{
"textual",
[]byte(`{"Type":"Bind","DestinationPortal":"","PreparedStatement":"lrupsc_1_0","ParameterFormatCodes":[0],"Parameters":[{"text":"ABC-123"}],"ResultFormatCodes":[0,0,0,0,0,1,1]}`),
},
{
"binary",
[]byte(`{"Type":"Bind","DestinationPortal":"","PreparedStatement":"lrupsc_1_0","ParameterFormatCodes":[0],"Parameters":[{"binary":"` + hex.EncodeToString([]byte("ABC-123")) + `"}],"ResultFormatCodes":[0,0,0,0,0,1,1]}`),
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
var want = Bind{
PreparedStatement: "lrupsc_1_0",
ParameterFormatCodes: []int16{0},
Parameters: [][]byte{[]byte("ABC-123")},
ResultFormatCodes: []int16{0, 0, 0, 0, 0, 1, 1},
}
var got Bind
if err := json.Unmarshal(tc.data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled Bind struct doesn't match expected value")
}
})
}
}
func TestJSONUnmarshalCancelRequest(t *testing.T) {
data := []byte(`{"Type":"CancelRequest","ProcessID":8864,"SecretKey":3641487067}`)
want := CancelRequest{
ProcessID: 8864,
SecretKey: 3641487067,
}
var got CancelRequest
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled CancelRequest struct doesn't match expected value")
}
}
func TestJSONUnmarshalClose(t *testing.T) {
data := []byte(`{"Type":"Close","ObjectType":"S","Name":"abc"}`)
want := Close{
ObjectType: 'S',
Name: "abc",
}
var got Close
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled Close struct doesn't match expected value")
}
}
func TestJSONUnmarshalCopyFail(t *testing.T) {
data := []byte(`{"Type":"CopyFail","Message":"abc"}`)
want := CopyFail{
Message: "abc",
}
var got CopyFail
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled CopyFail struct doesn't match expected value")
}
}
func TestJSONUnmarshalDescribe(t *testing.T) {
data := []byte(`{"Type":"Describe","ObjectType":"S","Name":"abc"}`)
want := Describe{
ObjectType: 'S',
Name: "abc",
}
var got Describe
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled Describe struct doesn't match expected value")
}
}
func TestJSONUnmarshalExecute(t *testing.T) {
data := []byte(`{"Type":"Execute","Portal":"","MaxRows":0}`)
want := Execute{}
var got Execute
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled Execute struct doesn't match expected value")
}
}
func TestJSONUnmarshalParse(t *testing.T) {
data := []byte(`{"Type":"Parse","Name":"lrupsc_1_0","Query":"SELECT id, name FROM t WHERE id = $1","ParameterOIDs":null}`)
want := Parse{
Name: "lrupsc_1_0",
Query: "SELECT id, name FROM t WHERE id = $1",
}
var got Parse
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled Parse struct doesn't match expected value")
}
}
func TestJSONUnmarshalPasswordMessage(t *testing.T) {
data := []byte(`{"Type":"PasswordMessage","Password":"abcdef"}`)
want := PasswordMessage{
Password: "abcdef",
}
var got PasswordMessage
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled PasswordMessage struct doesn't match expected value")
}
}
func TestJSONUnmarshalQuery(t *testing.T) {
data := []byte(`{"Type":"Query","String":"SELECT 1"}`)
want := Query{
String: "SELECT 1",
}
var got Query
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled Query struct doesn't match expected value")
}
}
func TestJSONUnmarshalSASLInitialResponse(t *testing.T) {
data := []byte(`{"Type":"SASLInitialResponse"}`)
want := SASLInitialResponse{}
var got SASLInitialResponse
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled SASLInitialResponse struct doesn't match expected value")
}
}
func TestJSONUnmarshalSASLResponse(t *testing.T) {
data := []byte(`{"Type":"SASLResponse","Message":"abc"}`)
want := SASLResponse{}
var got SASLResponse
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled SASLResponse struct doesn't match expected value")
}
}
func TestJSONUnmarshalStartupMessage(t *testing.T) {
data := []byte(`{"Type":"StartupMessage","ProtocolVersion":196608,"Parameters":{"database":"testing","user":"postgres"}}`)
want := StartupMessage{
ProtocolVersion: 196608,
Parameters: map[string]string{
"database": "testing",
"user": "postgres",
},
}
var got StartupMessage
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled StartupMessage struct doesn't match expected value")
}
}
+3
View File
@@ -47,6 +47,9 @@ func (e *invalidMessageFormatErr) Error() string {
// getValueFromJSON gets the value from a protocol message representation in JSON.
func getValueFromJSON(v map[string]string) ([]byte, error) {
if v == nil {
return nil, nil
}
if text, ok := v["text"]; ok {
return []byte(text), nil
}
+7 -5
View File
@@ -82,11 +82,13 @@ func (dst *SASLInitialResponse) UnmarshalJSON(data []byte) error {
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
decodedData, err := hex.DecodeString(msg.Data)
if err != nil {
return err
}
dst.AuthMechanism = msg.AuthMechanism
dst.Data = decodedData
if msg.Data != "" {
decoded, err := hex.DecodeString(msg.Data)
if err != nil {
return err
}
dst.Data = decoded
}
return nil
}
+6 -4
View File
@@ -50,10 +50,12 @@ func (dst *SASLResponse) UnmarshalJSON(data []byte) error {
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
decoded, err := hex.DecodeString(msg.Data)
if err != nil {
return err
if msg.Data != "" {
decoded, err := hex.DecodeString(msg.Data)
if err != nil {
return err
}
dst.Data = decoded
}
dst.Data = decoded
return nil
}