From 9c2c389e06738fc2fb5e3c15b5d51b125435b5a0 Mon Sep 17 00:00:00 2001 From: Henrique Vicente Date: Mon, 17 May 2021 02:11:29 +0200 Subject: [PATCH] json: fix implementation of json Unmarshalers. * AuthenticationMD5Password was wrong and is not needed * Bind was wrong * ErrorResponse is not needed * Minor improvements for reliability --- authentication_md5_password.go | 22 -- bind.go | 14 +- error_response.go | 67 ----- json_test.go | 508 +++++++++++++++++++++++++++++++++ pgproto3.go | 3 + sasl_initial_response.go | 12 +- sasl_response.go | 10 +- 7 files changed, 530 insertions(+), 106 deletions(-) create mode 100644 json_test.go diff --git a/authentication_md5_password.go b/authentication_md5_password.go index b80bd992..d505d264 100644 --- a/authentication_md5_password.go +++ b/authentication_md5_password.go @@ -2,7 +2,6 @@ package pgproto3 import ( "encoding/binary" - "encoding/json" "errors" "github.com/jackc/pgio" @@ -42,24 +41,3 @@ func (src *AuthenticationMD5Password) Encode(dst []byte) []byte { dst = append(dst, src.Salt[:]...) return dst } - -// UnmarshalJSON implements encoding/json.Unmarshaler. -func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error { - // Ignore null, like in the main JSON package. - if string(data) == "null" { - return nil - } - - var msg struct { - Salt string - } - if err := json.Unmarshal(data, &msg); err != nil { - return err - } - if len(msg.Salt) != 4 { - return errors.New("invalid salt size") - } - - copy(dst.Salt[:], []byte(msg.Salt)[:4]) - return nil -} diff --git a/bind.go b/bind.go index 57585c4d..e9664f59 100644 --- a/bind.go +++ b/bind.go @@ -201,15 +201,13 @@ func (dst *Bind) UnmarshalJSON(data []byte) error { if err != nil { return err } - bind := &Bind{ - DestinationPortal: msg.DestinationPortal, - PreparedStatement: msg.PreparedStatement, - ParameterFormatCodes: msg.ParameterFormatCodes, - Parameters: make([][]byte, len(msg.Parameters)), - ResultFormatCodes: msg.ResultFormatCodes, - } + dst.DestinationPortal = msg.DestinationPortal + dst.PreparedStatement = msg.PreparedStatement + dst.ParameterFormatCodes = msg.ParameterFormatCodes + dst.Parameters = make([][]byte, len(msg.Parameters)) + dst.ResultFormatCodes = msg.ResultFormatCodes for n, parameter := range msg.Parameters { - bind.Parameters[n], err = getValueFromJSON(parameter) + dst.Parameters[n], err = getValueFromJSON(parameter) if err != nil { return fmt.Errorf("cannot get param %d: %w", n, err) } diff --git a/error_response.go b/error_response.go index 9bbd78f4..4eb0a196 100644 --- a/error_response.go +++ b/error_response.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/binary" - "encoding/json" - "fmt" "strconv" ) @@ -227,68 +225,3 @@ func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { return buf.Bytes() } - -// UnmarshalJSON implements encoding/json.Unmarshaler. -func (dst *ErrorResponse) UnmarshalJSON(data []byte) error { - // Ignore null, like in the main JSON package. - if string(data) == "null" { - return nil - } - - var msg struct { - Severity string - SeverityUnlocalized string // only in 9.6 and greater - Code string - Message string - Detail string - Hint string - Position int32 - InternalPosition int32 - InternalQuery string - Where string - SchemaName string - TableName string - ColumnName string - DataTypeName string - ConstraintName string - File string - Line int32 - Routine string - - UnknownFields map[string]string - } - if err := json.Unmarshal(data, &msg); err != nil { - return err - } - - dst.Severity = msg.Severity - dst.SeverityUnlocalized = msg.SeverityUnlocalized - dst.Code = msg.Code - dst.Message = msg.Message - dst.Detail = msg.Detail - dst.Hint = msg.Hint - dst.Position = msg.Position - dst.InternalPosition = msg.InternalPosition - dst.InternalQuery = msg.InternalQuery - dst.Where = msg.Where - dst.SchemaName = msg.SchemaName - dst.TableName = msg.TableName - dst.ColumnName = msg.ColumnName - dst.DataTypeName = msg.DataTypeName - dst.ConstraintName = msg.ConstraintName - dst.File = msg.File - dst.Line = msg.Line - dst.Routine = msg.Routine - - if msg.UnknownFields != nil { - dst.UnknownFields = map[byte]string{} - } - for k, v := range msg.UnknownFields { - if len(k) != 1 { - return fmt.Errorf("invalid UnknownFields field %q value", k) - } - dst.UnknownFields[k[0]] = v - } - - return nil -} diff --git a/json_test.go b/json_test.go new file mode 100644 index 00000000..c73807ab --- /dev/null +++ b/json_test.go @@ -0,0 +1,508 @@ +package pgproto3 + +import ( + "encoding/hex" + "encoding/json" + "reflect" + "testing" +) + +func TestJSONUnmarshalAuthenticationMD5Password(t *testing.T) { + data := []byte(`{"Type":"AuthenticationMD5Password", "Salt":[97,98,99,100]}`) + want := AuthenticationMD5Password{ + Salt: [4]byte{'a', 'b', 'c', 'd'}, + } + + var got AuthenticationMD5Password + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationMD5Password struct doesn't match expected value") + } +} + +func TestJSONUnmarshalAuthenticationSASL(t *testing.T) { + data := []byte(`{"Type":"AuthenticationSASL", "AuthMechanisms":[]}`) + want := AuthenticationSASL{ + AuthMechanisms: []string{}, + } + + var got AuthenticationSASL + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationSASL struct doesn't match expected value") + } +} + +func TestJSONUnmarshalAuthenticationSASLContinue(t *testing.T) { + data := []byte(`{"Type":"AuthenticationSASLContinue"}`) + want := AuthenticationSASLContinue{ + Data: []byte{}, + } + + var got AuthenticationSASLContinue + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationSASLContinue struct doesn't match expected value") + } +} + +func TestJSONUnmarshalAuthenticationSASLFinal(t *testing.T) { + data := []byte(`{"Type":"AuthenticationSASLFinal"}`) + want := AuthenticationSASLFinal{ + Data: []byte{}, + } + + var got AuthenticationSASLFinal + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationSASLFinal struct doesn't match expected value") + } +} + +func TestJSONUnmarshalBackendKeyData(t *testing.T) { + data := []byte(`{"Type":"BackendKeyData","ProcessID":8864,"SecretKey":3641487067}`) + want := BackendKeyData{ + ProcessID: 8864, + SecretKey: 3641487067, + } + + var got BackendKeyData + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled BackendKeyData struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCommandComplete(t *testing.T) { + data := []byte(`{"Type":"CommandComplete","CommandTag":"SELECT 1"}`) + want := CommandComplete{ + CommandTag: []byte("SELECT 1"), + } + + var got CommandComplete + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CommandComplete struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCopyBothResponse(t *testing.T) { + data := []byte(`{"Type":"CopyBothResponse", "OverallFormat": "W"}`) + want := CopyBothResponse{ + OverallFormat: 'W', + } + + var got CopyBothResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CopyBothResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCopyData(t *testing.T) { + data := []byte(`{"Type":"CopyData"}`) + want := CopyData{ + Data: []byte{}, + } + + var got CopyData + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CopyData struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCopyInResponse(t *testing.T) { + data := []byte(`{"Type":"CopyBothResponse", "OverallFormat": "W"}`) + want := CopyBothResponse{ + OverallFormat: 'W', + } + + var got CopyBothResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CopyBothResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCopyOutResponse(t *testing.T) { + data := []byte(`{"Type":"CopyOutResponse", "OverallFormat": "W"}`) + want := CopyOutResponse{ + OverallFormat: 'W', + } + + var got CopyOutResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CopyOutResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalDataRow(t *testing.T) { + data := []byte(`{"Type":"DataRow","Values":[{"text":"abc"},{"text":"this is a test"},{"binary":"000263d3114d2e34"}]}`) + want := DataRow{ + Values: [][]byte{ + []byte("abc"), + []byte("this is a test"), + {0, 2, 99, 211, 17, 77, 46, 52}, + }, + } + + var got DataRow + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled DataRow struct doesn't match expected value") + } +} + +func TestJSONUnmarshalErrorResponse(t *testing.T) { + data := []byte(`{"Type":"ErrorResponse", "UnknownFields": {"97": "foo"}}`) + want := ErrorResponse{ + UnknownFields: map[byte]string{ + 'a': "foo", + }, + } + + var got ErrorResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled ErrorResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalFunctionCallResponse(t *testing.T) { + data := []byte(`{"Type":"FunctionCallResponse"}`) + want := FunctionCallResponse{} + + var got FunctionCallResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled FunctionCallResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalNoticeResponse(t *testing.T) { + data := []byte(`{"Type":"NoticeResponse", "UnknownFields": {"97": "foo"}}`) + want := NoticeResponse{ + UnknownFields: map[byte]string{ + 'a': "foo", + }, + } + + var got NoticeResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled NoticeResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalNotificationResponse(t *testing.T) { + data := []byte(`{"Type":"NotificationResponse"}`) + want := NotificationResponse{} + + var got NotificationResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled NotificationResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalParameterDescription(t *testing.T) { + data := []byte(`{"Type":"ParameterDescription", "ParameterOIDs": [25]}`) + want := ParameterDescription{ + ParameterOIDs: []uint32{25}, + } + + var got ParameterDescription + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled ParameterDescription struct doesn't match expected value") + } +} + +func TestJSONUnmarshalParameterStatus(t *testing.T) { + data := []byte(`{"Type":"ParameterStatus","Name":"TimeZone","Value":"Europe/Amsterdam"}`) + want := ParameterStatus{ + Name: "TimeZone", + Value: "Europe/Amsterdam", + } + + var got ParameterStatus + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled ParameterDescription struct doesn't match expected value") + } +} + +func TestJSONUnmarshalReadyForQuery(t *testing.T) { + data := []byte(`{"Type":"ReadyForQuery","TxStatus":"I"}`) + want := ReadyForQuery{ + TxStatus: 'I', + } + + var got ReadyForQuery + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled ParameterDescription struct doesn't match expected value") + } +} + +func TestJSONUnmarshalRowDescription(t *testing.T) { + data := []byte(`{"Type":"RowDescription","Fields":[{"Name":"generate_series","TableOID":0,"TableAttributeNumber":0,"DataTypeOID":23,"DataTypeSize":4,"TypeModifier":-1,"Format":0}]}`) + want := RowDescription{ + Fields: []FieldDescription{ + { + Name: []byte("generate_series"), + DataTypeOID: 23, + DataTypeSize: 4, + TypeModifier: -1, + }, + }, + } + + var got RowDescription + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled RowDescription struct doesn't match expected value") + } +} + +func TestJSONUnmarshalBind(t *testing.T) { + var testCases = []struct { + desc string + data []byte + }{ + { + "textual", + []byte(`{"Type":"Bind","DestinationPortal":"","PreparedStatement":"lrupsc_1_0","ParameterFormatCodes":[0],"Parameters":[{"text":"ABC-123"}],"ResultFormatCodes":[0,0,0,0,0,1,1]}`), + }, + { + "binary", + []byte(`{"Type":"Bind","DestinationPortal":"","PreparedStatement":"lrupsc_1_0","ParameterFormatCodes":[0],"Parameters":[{"binary":"` + hex.EncodeToString([]byte("ABC-123")) + `"}],"ResultFormatCodes":[0,0,0,0,0,1,1]}`), + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + var want = Bind{ + PreparedStatement: "lrupsc_1_0", + ParameterFormatCodes: []int16{0}, + Parameters: [][]byte{[]byte("ABC-123")}, + ResultFormatCodes: []int16{0, 0, 0, 0, 0, 1, 1}, + } + + var got Bind + if err := json.Unmarshal(tc.data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Bind struct doesn't match expected value") + } + }) + } +} + +func TestJSONUnmarshalCancelRequest(t *testing.T) { + data := []byte(`{"Type":"CancelRequest","ProcessID":8864,"SecretKey":3641487067}`) + want := CancelRequest{ + ProcessID: 8864, + SecretKey: 3641487067, + } + + var got CancelRequest + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CancelRequest struct doesn't match expected value") + } +} + +func TestJSONUnmarshalClose(t *testing.T) { + data := []byte(`{"Type":"Close","ObjectType":"S","Name":"abc"}`) + want := Close{ + ObjectType: 'S', + Name: "abc", + } + + var got Close + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Close struct doesn't match expected value") + } +} + +func TestJSONUnmarshalCopyFail(t *testing.T) { + data := []byte(`{"Type":"CopyFail","Message":"abc"}`) + want := CopyFail{ + Message: "abc", + } + + var got CopyFail + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled CopyFail struct doesn't match expected value") + } +} + +func TestJSONUnmarshalDescribe(t *testing.T) { + data := []byte(`{"Type":"Describe","ObjectType":"S","Name":"abc"}`) + want := Describe{ + ObjectType: 'S', + Name: "abc", + } + + var got Describe + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Describe struct doesn't match expected value") + } +} + +func TestJSONUnmarshalExecute(t *testing.T) { + data := []byte(`{"Type":"Execute","Portal":"","MaxRows":0}`) + want := Execute{} + + var got Execute + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Execute struct doesn't match expected value") + } +} + +func TestJSONUnmarshalParse(t *testing.T) { + data := []byte(`{"Type":"Parse","Name":"lrupsc_1_0","Query":"SELECT id, name FROM t WHERE id = $1","ParameterOIDs":null}`) + want := Parse{ + Name: "lrupsc_1_0", + Query: "SELECT id, name FROM t WHERE id = $1", + } + + var got Parse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Parse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalPasswordMessage(t *testing.T) { + data := []byte(`{"Type":"PasswordMessage","Password":"abcdef"}`) + want := PasswordMessage{ + Password: "abcdef", + } + + var got PasswordMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled PasswordMessage struct doesn't match expected value") + } +} + +func TestJSONUnmarshalQuery(t *testing.T) { + data := []byte(`{"Type":"Query","String":"SELECT 1"}`) + want := Query{ + String: "SELECT 1", + } + + var got Query + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled Query struct doesn't match expected value") + } +} + +func TestJSONUnmarshalSASLInitialResponse(t *testing.T) { + data := []byte(`{"Type":"SASLInitialResponse"}`) + want := SASLInitialResponse{} + + var got SASLInitialResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled SASLInitialResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalSASLResponse(t *testing.T) { + data := []byte(`{"Type":"SASLResponse","Message":"abc"}`) + want := SASLResponse{} + + var got SASLResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled SASLResponse struct doesn't match expected value") + } +} + +func TestJSONUnmarshalStartupMessage(t *testing.T) { + data := []byte(`{"Type":"StartupMessage","ProtocolVersion":196608,"Parameters":{"database":"testing","user":"postgres"}}`) + want := StartupMessage{ + ProtocolVersion: 196608, + Parameters: map[string]string{ + "database": "testing", + "user": "postgres", + }, + } + + var got StartupMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled StartupMessage struct doesn't match expected value") + } +} diff --git a/pgproto3.go b/pgproto3.go index 5b39362c..fb0782cf 100644 --- a/pgproto3.go +++ b/pgproto3.go @@ -47,6 +47,9 @@ func (e *invalidMessageFormatErr) Error() string { // getValueFromJSON gets the value from a protocol message representation in JSON. func getValueFromJSON(v map[string]string) ([]byte, error) { + if v == nil { + return nil, nil + } if text, ok := v["text"]; ok { return []byte(text), nil } diff --git a/sasl_initial_response.go b/sasl_initial_response.go index ce994c51..f7e5f36a 100644 --- a/sasl_initial_response.go +++ b/sasl_initial_response.go @@ -82,11 +82,13 @@ func (dst *SASLInitialResponse) UnmarshalJSON(data []byte) error { if err := json.Unmarshal(data, &msg); err != nil { return err } - decodedData, err := hex.DecodeString(msg.Data) - if err != nil { - return err - } dst.AuthMechanism = msg.AuthMechanism - dst.Data = decodedData + if msg.Data != "" { + decoded, err := hex.DecodeString(msg.Data) + if err != nil { + return err + } + dst.Data = decoded + } return nil } diff --git a/sasl_response.go b/sasl_response.go index df60c5f7..41fb4c39 100644 --- a/sasl_response.go +++ b/sasl_response.go @@ -50,10 +50,12 @@ func (dst *SASLResponse) UnmarshalJSON(data []byte) error { if err := json.Unmarshal(data, &msg); err != nil { return err } - decoded, err := hex.DecodeString(msg.Data) - if err != nil { - return err + if msg.Data != "" { + decoded, err := hex.DecodeString(msg.Data) + if err != nil { + return err + } + dst.Data = decoded } - dst.Data = decoded return nil }