json: Implement json.Unmarshaler for messages.
This will allow using pgmockproxy output as ingestion data for pgmock.
This commit is contained in:
committed by
Jack Christensen
parent
1213b69774
commit
ba924e5715
@@ -2,6 +2,7 @@ package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
@@ -41,3 +42,24 @@ func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
|
||||
dst = append(dst, src.Salt[:]...)
|
||||
return dst
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Salt string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(msg.Salt) != 4 {
|
||||
return errors.New("invalid salt size")
|
||||
}
|
||||
|
||||
copy(dst.Salt[:], []byte(msg.Salt)[:4])
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
@@ -46,3 +47,21 @@ func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *AuthenticationSASLContinue) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Data string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.Data = []byte(msg.Data)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
@@ -46,3 +47,21 @@ func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *AuthenticationSASLFinal) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Data string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.Data = []byte(msg.Data)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
@@ -181,3 +182,37 @@ func (src Bind) MarshalJSON() ([]byte, error) {
|
||||
ResultFormatCodes: src.ResultFormatCodes,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *Bind) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
DestinationPortal string
|
||||
PreparedStatement string
|
||||
ParameterFormatCodes []int16
|
||||
Parameters []map[string]string
|
||||
ResultFormatCodes []int16
|
||||
}
|
||||
err := json.Unmarshal(data, &msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
bind := &Bind{
|
||||
DestinationPortal: msg.DestinationPortal,
|
||||
PreparedStatement: msg.PreparedStatement,
|
||||
ParameterFormatCodes: msg.ParameterFormatCodes,
|
||||
Parameters: make([][]byte, len(msg.Parameters)),
|
||||
ResultFormatCodes: msg.ResultFormatCodes,
|
||||
}
|
||||
for n, parameter := range msg.Parameters {
|
||||
bind.Parameters[n], err = getValueFromJSON(parameter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot get param %d: %w", n, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package pgproto3
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
@@ -62,3 +63,27 @@ func (src Close) MarshalJSON() ([]byte, error) {
|
||||
Name: src.Name,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *Close) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
ObjectType string
|
||||
Name string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(msg.ObjectType) != 1 {
|
||||
return errors.New("invalid length for Close.ObjectType")
|
||||
}
|
||||
|
||||
dst.ObjectType = byte(msg.ObjectType[0])
|
||||
dst.Name = msg.Name
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -51,3 +51,21 @@ func (src CommandComplete) MarshalJSON() ([]byte, error) {
|
||||
CommandTag: string(src.CommandTag),
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *CommandComplete) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
CommandTag string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.CommandTag = []byte(msg.CommandTag)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
@@ -68,3 +69,27 @@ func (src CopyBothResponse) MarshalJSON() ([]byte, error) {
|
||||
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *CopyBothResponse) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
OverallFormat string
|
||||
ColumnFormatCodes []uint16
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(msg.OverallFormat) != 1 {
|
||||
return errors.New("invalid length for CopyBothResponse.OverallFormat")
|
||||
}
|
||||
|
||||
dst.OverallFormat = msg.OverallFormat[0]
|
||||
dst.ColumnFormatCodes = msg.ColumnFormatCodes
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -42,3 +42,21 @@ func (src CopyData) MarshalJSON() ([]byte, error) {
|
||||
Data: hex.EncodeToString(src.Data),
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *CopyData) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Data string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.Data = []byte(msg.Data)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
@@ -69,3 +70,27 @@ func (src CopyInResponse) MarshalJSON() ([]byte, error) {
|
||||
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *CopyInResponse) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
OverallFormat string
|
||||
ColumnFormatCodes []uint16
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(msg.OverallFormat) != 1 {
|
||||
return errors.New("invalid length for CopyInResponse.OverallFormat")
|
||||
}
|
||||
|
||||
dst.OverallFormat = msg.OverallFormat[0]
|
||||
dst.ColumnFormatCodes = msg.ColumnFormatCodes
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
@@ -69,3 +70,27 @@ func (src CopyOutResponse) MarshalJSON() ([]byte, error) {
|
||||
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *CopyOutResponse) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
OverallFormat string
|
||||
ColumnFormatCodes []uint16
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(msg.OverallFormat) != 1 {
|
||||
return errors.New("invalid length for CopyOutResponse.OverallFormat")
|
||||
}
|
||||
|
||||
dst.OverallFormat = msg.OverallFormat[0]
|
||||
dst.ColumnFormatCodes = msg.ColumnFormatCodes
|
||||
return nil
|
||||
}
|
||||
|
||||
+25
@@ -115,3 +115,28 @@ func (src DataRow) MarshalJSON() ([]byte, error) {
|
||||
Values: formattedValues,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *DataRow) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Values []map[string]string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.Values = make([][]byte, len(msg.Values))
|
||||
for n, parameter := range msg.Values {
|
||||
var err error
|
||||
dst.Values[n], err = getValueFromJSON(parameter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
+24
@@ -3,6 +3,7 @@ package pgproto3
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
@@ -62,3 +63,26 @@ func (src Describe) MarshalJSON() ([]byte, error) {
|
||||
Name: src.Name,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *Describe) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
ObjectType string
|
||||
Name string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(msg.ObjectType) != 1 {
|
||||
return errors.New("invalid length for Describe.ObjectType")
|
||||
}
|
||||
|
||||
dst.ObjectType = byte(msg.ObjectType[0])
|
||||
dst.Name = msg.Name
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package pgproto3
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
@@ -225,3 +227,68 @@ func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *ErrorResponse) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Severity string
|
||||
SeverityUnlocalized string // only in 9.6 and greater
|
||||
Code string
|
||||
Message string
|
||||
Detail string
|
||||
Hint string
|
||||
Position int32
|
||||
InternalPosition int32
|
||||
InternalQuery string
|
||||
Where string
|
||||
SchemaName string
|
||||
TableName string
|
||||
ColumnName string
|
||||
DataTypeName string
|
||||
ConstraintName string
|
||||
File string
|
||||
Line int32
|
||||
Routine string
|
||||
|
||||
UnknownFields map[string]string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.Severity = msg.Severity
|
||||
dst.SeverityUnlocalized = msg.SeverityUnlocalized
|
||||
dst.Code = msg.Code
|
||||
dst.Message = msg.Message
|
||||
dst.Detail = msg.Detail
|
||||
dst.Hint = msg.Hint
|
||||
dst.Position = msg.Position
|
||||
dst.InternalPosition = msg.InternalPosition
|
||||
dst.InternalQuery = msg.InternalQuery
|
||||
dst.Where = msg.Where
|
||||
dst.SchemaName = msg.SchemaName
|
||||
dst.TableName = msg.TableName
|
||||
dst.ColumnName = msg.ColumnName
|
||||
dst.DataTypeName = msg.DataTypeName
|
||||
dst.ConstraintName = msg.ConstraintName
|
||||
dst.File = msg.File
|
||||
dst.Line = msg.Line
|
||||
dst.Routine = msg.Routine
|
||||
|
||||
if msg.UnknownFields != nil {
|
||||
dst.UnknownFields = map[byte]string{}
|
||||
}
|
||||
for k, v := range msg.UnknownFields {
|
||||
if len(k) != 1 {
|
||||
return fmt.Errorf("invalid UnknownFields field %q value", k)
|
||||
}
|
||||
dst.UnknownFields[k[0]] = v
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -81,3 +81,21 @@ func (src FunctionCallResponse) MarshalJSON() ([]byte, error) {
|
||||
Result: formattedValue,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *FunctionCallResponse) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Result map[string]string
|
||||
}
|
||||
err := json.Unmarshal(data, &msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst.Result, err = getValueFromJSON(msg.Result)
|
||||
return err
|
||||
}
|
||||
|
||||
+16
-1
@@ -1,6 +1,10 @@
|
||||
package pgproto3
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Message is the interface implemented by an object that can decode and encode
|
||||
// a particular PostgreSQL message.
|
||||
@@ -40,3 +44,14 @@ type invalidMessageFormatErr struct {
|
||||
func (e *invalidMessageFormatErr) Error() string {
|
||||
return fmt.Sprintf("%s body is invalid", e.messageType)
|
||||
}
|
||||
|
||||
// getValueFromJSON gets the value from a protocol message representation in JSON.
|
||||
func getValueFromJSON(v map[string]string) ([]byte, error) {
|
||||
if text, ok := v["text"]; ok {
|
||||
return []byte(text), nil
|
||||
}
|
||||
if binary, ok := v["binary"]; ok {
|
||||
return hex.DecodeString(binary)
|
||||
}
|
||||
return nil, errors.New("unknown protocol representation")
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
)
|
||||
|
||||
type ReadyForQuery struct {
|
||||
@@ -38,3 +39,23 @@ func (src ReadyForQuery) MarshalJSON() ([]byte, error) {
|
||||
TxStatus: string(src.TxStatus),
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *ReadyForQuery) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
TxStatus string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(msg.TxStatus) != 1 {
|
||||
return errors.New("invalid length for ReadyForQuery.TxStatus")
|
||||
}
|
||||
dst.TxStatus = msg.TxStatus[0]
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -132,3 +132,34 @@ func (src RowDescription) MarshalJSON() ([]byte, error) {
|
||||
Fields: src.Fields,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *RowDescription) UnmarshalJSON(data []byte) error {
|
||||
var msg struct {
|
||||
Fields []struct {
|
||||
Name string
|
||||
TableOID uint32
|
||||
TableAttributeNumber uint16
|
||||
DataTypeOID uint32
|
||||
DataTypeSize int16
|
||||
TypeModifier int32
|
||||
Format int16
|
||||
}
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
dst.Fields = make([]FieldDescription, len(msg.Fields))
|
||||
for n, field := range msg.Fields {
|
||||
dst.Fields[n] = FieldDescription{
|
||||
Name: []byte(field.Name),
|
||||
TableOID: field.TableOID,
|
||||
TableAttributeNumber: field.TableAttributeNumber,
|
||||
DataTypeOID: field.DataTypeOID,
|
||||
DataTypeSize: field.DataTypeSize,
|
||||
TypeModifier: field.TypeModifier,
|
||||
Format: field.Format,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -67,3 +67,26 @@ func (src SASLInitialResponse) MarshalJSON() ([]byte, error) {
|
||||
Data: hex.EncodeToString(src.Data),
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *SASLInitialResponse) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
AuthMechanism string
|
||||
Data string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
decodedData, err := hex.DecodeString(msg.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst.AuthMechanism = msg.AuthMechanism
|
||||
dst.Data = decodedData
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -41,3 +41,19 @@ func (src SASLResponse) MarshalJSON() ([]byte, error) {
|
||||
Data: hex.EncodeToString(src.Data),
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *SASLResponse) UnmarshalJSON(data []byte) error {
|
||||
var msg struct {
|
||||
Data string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
decoded, err := hex.DecodeString(msg.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst.Data = decoded
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user