From ba924e5715ad0b06cf1e1ddcc343ded6e9420cf4 Mon Sep 17 00:00:00 2001 From: Henrique Vicente Date: Sun, 16 May 2021 02:05:24 +0200 Subject: [PATCH] json: Implement json.Unmarshaler for messages. This will allow using pgmockproxy output as ingestion data for pgmock. --- authentication_md5_password.go | 22 +++++++++++ authentication_sasl_continue.go | 19 ++++++++++ authentication_sasl_final.go | 19 ++++++++++ bind.go | 35 +++++++++++++++++ close.go | 25 ++++++++++++ command_complete.go | 18 +++++++++ copy_both_response.go | 25 ++++++++++++ copy_data.go | 18 +++++++++ copy_in_response.go | 25 ++++++++++++ copy_out_response.go | 25 ++++++++++++ data_row.go | 25 ++++++++++++ describe.go | 24 ++++++++++++ error_response.go | 67 +++++++++++++++++++++++++++++++++ function_call_response.go | 18 +++++++++ pgproto3.go | 17 ++++++++- ready_for_query.go | 21 +++++++++++ row_description.go | 31 +++++++++++++++ sasl_initial_response.go | 23 +++++++++++ sasl_response.go | 16 ++++++++ 19 files changed, 472 insertions(+), 1 deletion(-) diff --git a/authentication_md5_password.go b/authentication_md5_password.go index d505d264..b80bd992 100644 --- a/authentication_md5_password.go +++ b/authentication_md5_password.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/json" "errors" "github.com/jackc/pgio" @@ -41,3 +42,24 @@ func (src *AuthenticationMD5Password) Encode(dst []byte) []byte { dst = append(dst, src.Salt[:]...) return dst } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Salt string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + if len(msg.Salt) != 4 { + return errors.New("invalid salt size") + } + + copy(dst.Salt[:], []byte(msg.Salt)[:4]) + return nil +} diff --git a/authentication_sasl_continue.go b/authentication_sasl_continue.go index 1b918a6e..62a16c76 100644 --- a/authentication_sasl_continue.go +++ b/authentication_sasl_continue.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/json" "errors" "github.com/jackc/pgio" @@ -46,3 +47,21 @@ func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte { return dst } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *AuthenticationSASLContinue) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Data = []byte(msg.Data) + return nil +} diff --git a/authentication_sasl_final.go b/authentication_sasl_final.go index 11d35660..de5e454a 100644 --- a/authentication_sasl_final.go +++ b/authentication_sasl_final.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/json" "errors" "github.com/jackc/pgio" @@ -46,3 +47,21 @@ func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte { return dst } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *AuthenticationSASLFinal) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Data = []byte(msg.Data) + return nil +} diff --git a/bind.go b/bind.go index 52372095..57585c4d 100644 --- a/bind.go +++ b/bind.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/hex" "encoding/json" + "fmt" "github.com/jackc/pgio" ) @@ -181,3 +182,37 @@ func (src Bind) MarshalJSON() ([]byte, error) { ResultFormatCodes: src.ResultFormatCodes, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *Bind) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + DestinationPortal string + PreparedStatement string + ParameterFormatCodes []int16 + Parameters []map[string]string + ResultFormatCodes []int16 + } + err := json.Unmarshal(data, &msg) + if err != nil { + return err + } + bind := &Bind{ + DestinationPortal: msg.DestinationPortal, + PreparedStatement: msg.PreparedStatement, + ParameterFormatCodes: msg.ParameterFormatCodes, + Parameters: make([][]byte, len(msg.Parameters)), + ResultFormatCodes: msg.ResultFormatCodes, + } + for n, parameter := range msg.Parameters { + bind.Parameters[n], err = getValueFromJSON(parameter) + if err != nil { + return fmt.Errorf("cannot get param %d: %w", n, err) + } + } + return nil +} diff --git a/close.go b/close.go index 38296909..a45f2b93 100644 --- a/close.go +++ b/close.go @@ -3,6 +3,7 @@ package pgproto3 import ( "bytes" "encoding/json" + "errors" "github.com/jackc/pgio" ) @@ -62,3 +63,27 @@ func (src Close) MarshalJSON() ([]byte, error) { Name: src.Name, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *Close) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + ObjectType string + Name string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.ObjectType) != 1 { + return errors.New("invalid length for Close.ObjectType") + } + + dst.ObjectType = byte(msg.ObjectType[0]) + dst.Name = msg.Name + return nil +} diff --git a/command_complete.go b/command_complete.go index b5106fda..cdc49f39 100644 --- a/command_complete.go +++ b/command_complete.go @@ -51,3 +51,21 @@ func (src CommandComplete) MarshalJSON() ([]byte, error) { CommandTag: string(src.CommandTag), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CommandComplete) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + CommandTag string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.CommandTag = []byte(msg.CommandTag) + return nil +} diff --git a/copy_both_response.go b/copy_both_response.go index 2d58f820..fbd985d8 100644 --- a/copy_both_response.go +++ b/copy_both_response.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" "github.com/jackc/pgio" ) @@ -68,3 +69,27 @@ func (src CopyBothResponse) MarshalJSON() ([]byte, error) { ColumnFormatCodes: src.ColumnFormatCodes, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyBothResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + OverallFormat string + ColumnFormatCodes []uint16 + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.OverallFormat) != 1 { + return errors.New("invalid length for CopyBothResponse.OverallFormat") + } + + dst.OverallFormat = msg.OverallFormat[0] + dst.ColumnFormatCodes = msg.ColumnFormatCodes + return nil +} diff --git a/copy_data.go b/copy_data.go index 7d6002fe..128aa198 100644 --- a/copy_data.go +++ b/copy_data.go @@ -42,3 +42,21 @@ func (src CopyData) MarshalJSON() ([]byte, error) { Data: hex.EncodeToString(src.Data), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyData) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Data = []byte(msg.Data) + return nil +} diff --git a/copy_in_response.go b/copy_in_response.go index 5f2595b8..80733adc 100644 --- a/copy_in_response.go +++ b/copy_in_response.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" "github.com/jackc/pgio" ) @@ -69,3 +70,27 @@ func (src CopyInResponse) MarshalJSON() ([]byte, error) { ColumnFormatCodes: src.ColumnFormatCodes, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyInResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + OverallFormat string + ColumnFormatCodes []uint16 + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.OverallFormat) != 1 { + return errors.New("invalid length for CopyInResponse.OverallFormat") + } + + dst.OverallFormat = msg.OverallFormat[0] + dst.ColumnFormatCodes = msg.ColumnFormatCodes + return nil +} diff --git a/copy_out_response.go b/copy_out_response.go index 8538dfc7..5e607e3a 100644 --- a/copy_out_response.go +++ b/copy_out_response.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" "github.com/jackc/pgio" ) @@ -69,3 +70,27 @@ func (src CopyOutResponse) MarshalJSON() ([]byte, error) { ColumnFormatCodes: src.ColumnFormatCodes, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyOutResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + OverallFormat string + ColumnFormatCodes []uint16 + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.OverallFormat) != 1 { + return errors.New("invalid length for CopyOutResponse.OverallFormat") + } + + dst.OverallFormat = msg.OverallFormat[0] + dst.ColumnFormatCodes = msg.ColumnFormatCodes + return nil +} diff --git a/data_row.go b/data_row.go index 5fa3c5d8..63768761 100644 --- a/data_row.go +++ b/data_row.go @@ -115,3 +115,28 @@ func (src DataRow) MarshalJSON() ([]byte, error) { Values: formattedValues, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *DataRow) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Values []map[string]string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Values = make([][]byte, len(msg.Values)) + for n, parameter := range msg.Values { + var err error + dst.Values[n], err = getValueFromJSON(parameter) + if err != nil { + return err + } + } + return nil +} diff --git a/describe.go b/describe.go index 308f582e..0d825db1 100644 --- a/describe.go +++ b/describe.go @@ -3,6 +3,7 @@ package pgproto3 import ( "bytes" "encoding/json" + "errors" "github.com/jackc/pgio" ) @@ -62,3 +63,26 @@ func (src Describe) MarshalJSON() ([]byte, error) { Name: src.Name, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *Describe) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + ObjectType string + Name string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + if len(msg.ObjectType) != 1 { + return errors.New("invalid length for Describe.ObjectType") + } + + dst.ObjectType = byte(msg.ObjectType[0]) + dst.Name = msg.Name + return nil +} diff --git a/error_response.go b/error_response.go index 4eb0a196..9bbd78f4 100644 --- a/error_response.go +++ b/error_response.go @@ -3,6 +3,8 @@ package pgproto3 import ( "bytes" "encoding/binary" + "encoding/json" + "fmt" "strconv" ) @@ -225,3 +227,68 @@ func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { return buf.Bytes() } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *ErrorResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Severity string + SeverityUnlocalized string // only in 9.6 and greater + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string + + UnknownFields map[string]string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Severity = msg.Severity + dst.SeverityUnlocalized = msg.SeverityUnlocalized + dst.Code = msg.Code + dst.Message = msg.Message + dst.Detail = msg.Detail + dst.Hint = msg.Hint + dst.Position = msg.Position + dst.InternalPosition = msg.InternalPosition + dst.InternalQuery = msg.InternalQuery + dst.Where = msg.Where + dst.SchemaName = msg.SchemaName + dst.TableName = msg.TableName + dst.ColumnName = msg.ColumnName + dst.DataTypeName = msg.DataTypeName + dst.ConstraintName = msg.ConstraintName + dst.File = msg.File + dst.Line = msg.Line + dst.Routine = msg.Routine + + if msg.UnknownFields != nil { + dst.UnknownFields = map[byte]string{} + } + for k, v := range msg.UnknownFields { + if len(k) != 1 { + return fmt.Errorf("invalid UnknownFields field %q value", k) + } + dst.UnknownFields[k[0]] = v + } + + return nil +} diff --git a/function_call_response.go b/function_call_response.go index 5cc2d4d2..53d64222 100644 --- a/function_call_response.go +++ b/function_call_response.go @@ -81,3 +81,21 @@ func (src FunctionCallResponse) MarshalJSON() ([]byte, error) { Result: formattedValue, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *FunctionCallResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Result map[string]string + } + err := json.Unmarshal(data, &msg) + if err != nil { + return err + } + dst.Result, err = getValueFromJSON(msg.Result) + return err +} diff --git a/pgproto3.go b/pgproto3.go index fe7b085b..5b39362c 100644 --- a/pgproto3.go +++ b/pgproto3.go @@ -1,6 +1,10 @@ package pgproto3 -import "fmt" +import ( + "encoding/hex" + "errors" + "fmt" +) // Message is the interface implemented by an object that can decode and encode // a particular PostgreSQL message. @@ -40,3 +44,14 @@ type invalidMessageFormatErr struct { func (e *invalidMessageFormatErr) Error() string { return fmt.Sprintf("%s body is invalid", e.messageType) } + +// getValueFromJSON gets the value from a protocol message representation in JSON. +func getValueFromJSON(v map[string]string) ([]byte, error) { + if text, ok := v["text"]; ok { + return []byte(text), nil + } + if binary, ok := v["binary"]; ok { + return hex.DecodeString(binary) + } + return nil, errors.New("unknown protocol representation") +} diff --git a/ready_for_query.go b/ready_for_query.go index 879afe39..67a39be3 100644 --- a/ready_for_query.go +++ b/ready_for_query.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/json" + "errors" ) type ReadyForQuery struct { @@ -38,3 +39,23 @@ func (src ReadyForQuery) MarshalJSON() ([]byte, error) { TxStatus: string(src.TxStatus), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *ReadyForQuery) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + TxStatus string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + if len(msg.TxStatus) != 1 { + return errors.New("invalid length for ReadyForQuery.TxStatus") + } + dst.TxStatus = msg.TxStatus[0] + return nil +} diff --git a/row_description.go b/row_description.go index d9b8c7c9..a2e0d28e 100644 --- a/row_description.go +++ b/row_description.go @@ -132,3 +132,34 @@ func (src RowDescription) MarshalJSON() ([]byte, error) { Fields: src.Fields, }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *RowDescription) UnmarshalJSON(data []byte) error { + var msg struct { + Fields []struct { + Name string + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier int32 + Format int16 + } + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + dst.Fields = make([]FieldDescription, len(msg.Fields)) + for n, field := range msg.Fields { + dst.Fields[n] = FieldDescription{ + Name: []byte(field.Name), + TableOID: field.TableOID, + TableAttributeNumber: field.TableAttributeNumber, + DataTypeOID: field.DataTypeOID, + DataTypeSize: field.DataTypeSize, + TypeModifier: field.TypeModifier, + Format: field.Format, + } + } + return nil +} diff --git a/sasl_initial_response.go b/sasl_initial_response.go index 0bf8a9e5..ce994c51 100644 --- a/sasl_initial_response.go +++ b/sasl_initial_response.go @@ -67,3 +67,26 @@ func (src SASLInitialResponse) MarshalJSON() ([]byte, error) { Data: hex.EncodeToString(src.Data), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *SASLInitialResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + AuthMechanism string + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + decodedData, err := hex.DecodeString(msg.Data) + if err != nil { + return err + } + dst.AuthMechanism = msg.AuthMechanism + dst.Data = decodedData + return nil +} diff --git a/sasl_response.go b/sasl_response.go index 21be6d75..df60c5f7 100644 --- a/sasl_response.go +++ b/sasl_response.go @@ -41,3 +41,19 @@ func (src SASLResponse) MarshalJSON() ([]byte, error) { Data: hex.EncodeToString(src.Data), }) } + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *SASLResponse) UnmarshalJSON(data []byte) error { + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + decoded, err := hex.DecodeString(msg.Data) + if err != nil { + return err + } + dst.Data = decoded + return nil +}