2
0

json: Implement json.Unmarshaler for messages.

This will allow using pgmockproxy output as ingestion data for pgmock.
This commit is contained in:
Henrique Vicente
2021-05-16 02:05:24 +02:00
committed by Jack Christensen
parent 1213b69774
commit ba924e5715
19 changed files with 472 additions and 1 deletions
+22
View File
@@ -2,6 +2,7 @@ package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
@@ -41,3 +42,24 @@ func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
dst = append(dst, src.Salt[:]...)
return dst
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Salt string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.Salt) != 4 {
return errors.New("invalid salt size")
}
copy(dst.Salt[:], []byte(msg.Salt)[:4])
return nil
}
+19
View File
@@ -2,6 +2,7 @@ package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
@@ -46,3 +47,21 @@ func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
return dst
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *AuthenticationSASLContinue) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Data string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Data = []byte(msg.Data)
return nil
}
+19
View File
@@ -2,6 +2,7 @@ package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
@@ -46,3 +47,21 @@ func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
return dst
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *AuthenticationSASLFinal) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Data string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Data = []byte(msg.Data)
return nil
}
+35
View File
@@ -5,6 +5,7 @@ import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"github.com/jackc/pgio"
)
@@ -181,3 +182,37 @@ func (src Bind) MarshalJSON() ([]byte, error) {
ResultFormatCodes: src.ResultFormatCodes,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *Bind) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
DestinationPortal string
PreparedStatement string
ParameterFormatCodes []int16
Parameters []map[string]string
ResultFormatCodes []int16
}
err := json.Unmarshal(data, &msg)
if err != nil {
return err
}
bind := &Bind{
DestinationPortal: msg.DestinationPortal,
PreparedStatement: msg.PreparedStatement,
ParameterFormatCodes: msg.ParameterFormatCodes,
Parameters: make([][]byte, len(msg.Parameters)),
ResultFormatCodes: msg.ResultFormatCodes,
}
for n, parameter := range msg.Parameters {
bind.Parameters[n], err = getValueFromJSON(parameter)
if err != nil {
return fmt.Errorf("cannot get param %d: %w", n, err)
}
}
return nil
}
+25
View File
@@ -3,6 +3,7 @@ package pgproto3
import (
"bytes"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
@@ -62,3 +63,27 @@ func (src Close) MarshalJSON() ([]byte, error) {
Name: src.Name,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *Close) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
ObjectType string
Name string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.ObjectType) != 1 {
return errors.New("invalid length for Close.ObjectType")
}
dst.ObjectType = byte(msg.ObjectType[0])
dst.Name = msg.Name
return nil
}
+18
View File
@@ -51,3 +51,21 @@ func (src CommandComplete) MarshalJSON() ([]byte, error) {
CommandTag: string(src.CommandTag),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CommandComplete) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
CommandTag string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.CommandTag = []byte(msg.CommandTag)
return nil
}
+25
View File
@@ -4,6 +4,7 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
@@ -68,3 +69,27 @@ func (src CopyBothResponse) MarshalJSON() ([]byte, error) {
ColumnFormatCodes: src.ColumnFormatCodes,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CopyBothResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
OverallFormat string
ColumnFormatCodes []uint16
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.OverallFormat) != 1 {
return errors.New("invalid length for CopyBothResponse.OverallFormat")
}
dst.OverallFormat = msg.OverallFormat[0]
dst.ColumnFormatCodes = msg.ColumnFormatCodes
return nil
}
+18
View File
@@ -42,3 +42,21 @@ func (src CopyData) MarshalJSON() ([]byte, error) {
Data: hex.EncodeToString(src.Data),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CopyData) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Data string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Data = []byte(msg.Data)
return nil
}
+25
View File
@@ -4,6 +4,7 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
@@ -69,3 +70,27 @@ func (src CopyInResponse) MarshalJSON() ([]byte, error) {
ColumnFormatCodes: src.ColumnFormatCodes,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CopyInResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
OverallFormat string
ColumnFormatCodes []uint16
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.OverallFormat) != 1 {
return errors.New("invalid length for CopyInResponse.OverallFormat")
}
dst.OverallFormat = msg.OverallFormat[0]
dst.ColumnFormatCodes = msg.ColumnFormatCodes
return nil
}
+25
View File
@@ -4,6 +4,7 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
@@ -69,3 +70,27 @@ func (src CopyOutResponse) MarshalJSON() ([]byte, error) {
ColumnFormatCodes: src.ColumnFormatCodes,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CopyOutResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
OverallFormat string
ColumnFormatCodes []uint16
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.OverallFormat) != 1 {
return errors.New("invalid length for CopyOutResponse.OverallFormat")
}
dst.OverallFormat = msg.OverallFormat[0]
dst.ColumnFormatCodes = msg.ColumnFormatCodes
return nil
}
+25
View File
@@ -115,3 +115,28 @@ func (src DataRow) MarshalJSON() ([]byte, error) {
Values: formattedValues,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *DataRow) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Values []map[string]string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Values = make([][]byte, len(msg.Values))
for n, parameter := range msg.Values {
var err error
dst.Values[n], err = getValueFromJSON(parameter)
if err != nil {
return err
}
}
return nil
}
+24
View File
@@ -3,6 +3,7 @@ package pgproto3
import (
"bytes"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
@@ -62,3 +63,26 @@ func (src Describe) MarshalJSON() ([]byte, error) {
Name: src.Name,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *Describe) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
ObjectType string
Name string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.ObjectType) != 1 {
return errors.New("invalid length for Describe.ObjectType")
}
dst.ObjectType = byte(msg.ObjectType[0])
dst.Name = msg.Name
return nil
}
+67
View File
@@ -3,6 +3,8 @@ package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"strconv"
)
@@ -225,3 +227,68 @@ func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
return buf.Bytes()
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *ErrorResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Severity string
SeverityUnlocalized string // only in 9.6 and greater
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
UnknownFields map[string]string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Severity = msg.Severity
dst.SeverityUnlocalized = msg.SeverityUnlocalized
dst.Code = msg.Code
dst.Message = msg.Message
dst.Detail = msg.Detail
dst.Hint = msg.Hint
dst.Position = msg.Position
dst.InternalPosition = msg.InternalPosition
dst.InternalQuery = msg.InternalQuery
dst.Where = msg.Where
dst.SchemaName = msg.SchemaName
dst.TableName = msg.TableName
dst.ColumnName = msg.ColumnName
dst.DataTypeName = msg.DataTypeName
dst.ConstraintName = msg.ConstraintName
dst.File = msg.File
dst.Line = msg.Line
dst.Routine = msg.Routine
if msg.UnknownFields != nil {
dst.UnknownFields = map[byte]string{}
}
for k, v := range msg.UnknownFields {
if len(k) != 1 {
return fmt.Errorf("invalid UnknownFields field %q value", k)
}
dst.UnknownFields[k[0]] = v
}
return nil
}
+18
View File
@@ -81,3 +81,21 @@ func (src FunctionCallResponse) MarshalJSON() ([]byte, error) {
Result: formattedValue,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *FunctionCallResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Result map[string]string
}
err := json.Unmarshal(data, &msg)
if err != nil {
return err
}
dst.Result, err = getValueFromJSON(msg.Result)
return err
}
+16 -1
View File
@@ -1,6 +1,10 @@
package pgproto3
import "fmt"
import (
"encoding/hex"
"errors"
"fmt"
)
// Message is the interface implemented by an object that can decode and encode
// a particular PostgreSQL message.
@@ -40,3 +44,14 @@ type invalidMessageFormatErr struct {
func (e *invalidMessageFormatErr) Error() string {
return fmt.Sprintf("%s body is invalid", e.messageType)
}
// getValueFromJSON gets the value from a protocol message representation in JSON.
func getValueFromJSON(v map[string]string) ([]byte, error) {
if text, ok := v["text"]; ok {
return []byte(text), nil
}
if binary, ok := v["binary"]; ok {
return hex.DecodeString(binary)
}
return nil, errors.New("unknown protocol representation")
}
+21
View File
@@ -2,6 +2,7 @@ package pgproto3
import (
"encoding/json"
"errors"
)
type ReadyForQuery struct {
@@ -38,3 +39,23 @@ func (src ReadyForQuery) MarshalJSON() ([]byte, error) {
TxStatus: string(src.TxStatus),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *ReadyForQuery) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
TxStatus string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.TxStatus) != 1 {
return errors.New("invalid length for ReadyForQuery.TxStatus")
}
dst.TxStatus = msg.TxStatus[0]
return nil
}
+31
View File
@@ -132,3 +132,34 @@ func (src RowDescription) MarshalJSON() ([]byte, error) {
Fields: src.Fields,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *RowDescription) UnmarshalJSON(data []byte) error {
var msg struct {
Fields []struct {
Name string
TableOID uint32
TableAttributeNumber uint16
DataTypeOID uint32
DataTypeSize int16
TypeModifier int32
Format int16
}
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Fields = make([]FieldDescription, len(msg.Fields))
for n, field := range msg.Fields {
dst.Fields[n] = FieldDescription{
Name: []byte(field.Name),
TableOID: field.TableOID,
TableAttributeNumber: field.TableAttributeNumber,
DataTypeOID: field.DataTypeOID,
DataTypeSize: field.DataTypeSize,
TypeModifier: field.TypeModifier,
Format: field.Format,
}
}
return nil
}
+23
View File
@@ -67,3 +67,26 @@ func (src SASLInitialResponse) MarshalJSON() ([]byte, error) {
Data: hex.EncodeToString(src.Data),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *SASLInitialResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
AuthMechanism string
Data string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
decodedData, err := hex.DecodeString(msg.Data)
if err != nil {
return err
}
dst.AuthMechanism = msg.AuthMechanism
dst.Data = decodedData
return nil
}
+16
View File
@@ -41,3 +41,19 @@ func (src SASLResponse) MarshalJSON() ([]byte, error) {
Data: hex.EncodeToString(src.Data),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *SASLResponse) UnmarshalJSON(data []byte) error {
var msg struct {
Data string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
decoded, err := hex.DecodeString(msg.Data)
if err != nil {
return err
}
dst.Data = decoded
return nil
}