From 4e2900b774649cbc58686ac047bf84f54aff4e18 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Apr 2017 10:02:38 -0500 Subject: [PATCH] Introduce pgproto3 package pgproto3 will wrap the message encoding and decoding for the PostgreSQL frontend/backend protocol version 3. --- authentication.go | 54 +++++++++++ backend_key_data.go | 47 +++++++++ big_endian.go | 37 +++++++ bind_complete.go | 29 ++++++ close_complete.go | 29 ++++++ command_complete.go | 47 +++++++++ copy_both_response.go | 64 +++++++++++++ copy_data.go | 41 ++++++++ copy_in_response.go | 64 +++++++++++++ copy_out_response.go | 64 +++++++++++++ data_row.go | 103 ++++++++++++++++++++ empty_query_response.go | 29 ++++++ error_response.go | 197 ++++++++++++++++++++++++++++++++++++++ frontend.go | 70 ++++++++++++++ function_call_response.go | 73 ++++++++++++++ no_data.go | 29 ++++++ notice_response.go | 13 +++ notification_response.go | 65 +++++++++++++ parameter_description.go | 60 ++++++++++++ parameter_status.go | 62 ++++++++++++ parse_complete.go | 29 ++++++ pgproto3.go | 88 +++++++++++++++++ query.go | 43 +++++++++ ready_for_query.go | 35 +++++++ row_description.go | 101 +++++++++++++++++++ 25 files changed, 1473 insertions(+) create mode 100644 authentication.go create mode 100644 backend_key_data.go create mode 100644 big_endian.go create mode 100644 bind_complete.go create mode 100644 close_complete.go create mode 100644 command_complete.go create mode 100644 copy_both_response.go create mode 100644 copy_data.go create mode 100644 copy_in_response.go create mode 100644 copy_out_response.go create mode 100644 data_row.go create mode 100644 empty_query_response.go create mode 100644 error_response.go create mode 100644 frontend.go create mode 100644 function_call_response.go create mode 100644 no_data.go create mode 100644 notice_response.go create mode 100644 notification_response.go create mode 100644 parameter_description.go create mode 100644 parameter_status.go create mode 100644 parse_complete.go create mode 100644 pgproto3.go create mode 100644 query.go create mode 100644 ready_for_query.go create mode 100644 row_description.go diff --git a/authentication.go b/authentication.go new file mode 100644 index 00000000..e265a247 --- /dev/null +++ b/authentication.go @@ -0,0 +1,54 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "fmt" +) + +const ( + AuthTypeOk = 0 + AuthTypeCleartextPassword = 3 + AuthTypeMD5Password = 5 +) + +type Authentication struct { + Type uint32 + + // MD5Password fields + Salt [4]byte +} + +func (*Authentication) Backend() {} + +func (dst *Authentication) UnmarshalBinary(src []byte) error { + *dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])} + + switch dst.Type { + case AuthTypeOk: + case AuthTypeCleartextPassword: + case AuthTypeMD5Password: + copy(dst.Salt[:], src[4:8]) + default: + return fmt.Errorf("unknown authentication type: %d", dst.Type) + } + + return nil +} + +func (src *Authentication) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + buf.WriteByte('R') + buf.Write(bigEndian.Uint32(0)) + buf.Write(bigEndian.Uint32(src.Type)) + + switch src.Type { + case AuthTypeMD5Password: + buf.Write(src.Salt[:]) + } + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} diff --git a/backend_key_data.go b/backend_key_data.go new file mode 100644 index 00000000..5d8eb496 --- /dev/null +++ b/backend_key_data.go @@ -0,0 +1,47 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type BackendKeyData struct { + ProcessID uint32 + SecretKey uint32 +} + +func (*BackendKeyData) Backend() {} + +func (dst *BackendKeyData) UnmarshalBinary(src []byte) error { + if len(src) != 8 { + return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)} + } + + dst.ProcessID = binary.BigEndian.Uint32(src[:4]) + dst.SecretKey = binary.BigEndian.Uint32(src[4:]) + + return nil +} + +func (src *BackendKeyData) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + buf.WriteByte('K') + buf.Write(bigEndian.Uint32(12)) + buf.Write(bigEndian.Uint32(src.ProcessID)) + buf.Write(bigEndian.Uint32(src.SecretKey)) + return buf.Bytes(), nil +} + +func (src *BackendKeyData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProcessID uint32 + SecretKey uint32 + }{ + Type: "BackendKeyData", + ProcessID: src.ProcessID, + SecretKey: src.SecretKey, + }) +} diff --git a/big_endian.go b/big_endian.go new file mode 100644 index 00000000..f7bdb97e --- /dev/null +++ b/big_endian.go @@ -0,0 +1,37 @@ +package pgproto3 + +import ( + "encoding/binary" +) + +type BigEndianBuf [8]byte + +func (b BigEndianBuf) Int16(n int16) []byte { + buf := b[0:2] + binary.BigEndian.PutUint16(buf, uint16(n)) + return buf +} + +func (b BigEndianBuf) Uint16(n uint16) []byte { + buf := b[0:2] + binary.BigEndian.PutUint16(buf, n) + return buf +} + +func (b BigEndianBuf) Int32(n int32) []byte { + buf := b[0:4] + binary.BigEndian.PutUint32(buf, uint32(n)) + return buf +} + +func (b BigEndianBuf) Uint32(n uint32) []byte { + buf := b[0:4] + binary.BigEndian.PutUint32(buf, n) + return buf +} + +func (b BigEndianBuf) Int64(n int64) []byte { + buf := b[0:8] + binary.BigEndian.PutUint64(buf, uint64(n)) + return buf +} diff --git a/bind_complete.go b/bind_complete.go new file mode 100644 index 00000000..756a30e6 --- /dev/null +++ b/bind_complete.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type BindComplete struct{} + +func (*BindComplete) Backend() {} + +func (dst *BindComplete) UnmarshalBinary(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *BindComplete) MarshalBinary() ([]byte, error) { + return []byte{'2', 0, 0, 0, 4}, nil +} + +func (src *BindComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "BindComplete", + }) +} diff --git a/close_complete.go b/close_complete.go new file mode 100644 index 00000000..fd6ff180 --- /dev/null +++ b/close_complete.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type CloseComplete struct{} + +func (*CloseComplete) Backend() {} + +func (dst *CloseComplete) UnmarshalBinary(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *CloseComplete) MarshalBinary() ([]byte, error) { + return []byte{'3', 0, 0, 0, 4}, nil +} + +func (src *CloseComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "CloseComplete", + }) +} diff --git a/command_complete.go b/command_complete.go new file mode 100644 index 00000000..ac60153e --- /dev/null +++ b/command_complete.go @@ -0,0 +1,47 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" +) + +type CommandComplete struct { + CommandTag string +} + +func (*CommandComplete) Backend() {} + +func (dst *CommandComplete) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.CommandTag = string(b[:len(b)-1]) + + return nil +} + +func (src *CommandComplete) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('C') + buf.Write(bigEndian.Uint32(uint32(4 + len(src.CommandTag) + 1))) + + buf.WriteString(src.CommandTag) + buf.WriteByte(0) + + return buf.Bytes(), nil +} + +func (src *CommandComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + CommandTag string + }{ + Type: "CommandComplete", + CommandTag: src.CommandTag, + }) +} diff --git a/copy_both_response.go b/copy_both_response.go new file mode 100644 index 00000000..2a4c58af --- /dev/null +++ b/copy_both_response.go @@ -0,0 +1,64 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type CopyBothResponse struct { + OverallFormat byte + ColumnFormatCodes []uint16 +} + +func (*CopyBothResponse) Backend() {} + +func (dst *CopyBothResponse) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 3 { + return &invalidMessageFormatErr{messageType: "CopyBothResponse"} + } + + overallFormat := buf.Next(1)[0] + + columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) + if buf.Len() != columnCount*2 { + return &invalidMessageFormatErr{messageType: "CopyBothResponse"} + } + + columnFormatCodes := make([]uint16, columnCount) + for i := 0; i < columnCount; i++ { + columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) + } + + *dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} + + return nil +} + +func (src *CopyBothResponse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('W') + buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes)))) + + buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes)))) + + for _, fc := range src.ColumnFormatCodes { + buf.Write(bigEndian.Uint16(fc)) + } + + return buf.Bytes(), nil +} + +func (src *CopyBothResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ColumnFormatCodes []uint16 + }{ + Type: "CopyBothResponse", + ColumnFormatCodes: src.ColumnFormatCodes, + }) +} diff --git a/copy_data.go b/copy_data.go new file mode 100644 index 00000000..b9ea6272 --- /dev/null +++ b/copy_data.go @@ -0,0 +1,41 @@ +package pgproto3 + +import ( + "bytes" + "encoding/hex" + "encoding/json" +) + +type CopyData struct { + Data []byte +} + +func (*CopyData) Backend() {} +func (*CopyData) Frontend() {} + +func (dst *CopyData) UnmarshalBinary(src []byte) error { + dst.Data = make([]byte, len(src)) + copy(dst.Data, src) + return nil +} + +func (src *CopyData) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('d') + buf.Write(bigEndian.Uint32(uint32(4 + len(src.Data)))) + buf.Write(src.Data) + + return buf.Bytes(), nil +} + +func (src *CopyData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "CopyData", + Data: hex.EncodeToString(src.Data), + }) +} diff --git a/copy_in_response.go b/copy_in_response.go new file mode 100644 index 00000000..63868c7a --- /dev/null +++ b/copy_in_response.go @@ -0,0 +1,64 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type CopyInResponse struct { + OverallFormat byte + ColumnFormatCodes []uint16 +} + +func (*CopyInResponse) Backend() {} + +func (dst *CopyInResponse) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 3 { + return &invalidMessageFormatErr{messageType: "CopyInResponse"} + } + + overallFormat := buf.Next(1)[0] + + columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) + if buf.Len() != columnCount*2 { + return &invalidMessageFormatErr{messageType: "CopyInResponse"} + } + + columnFormatCodes := make([]uint16, columnCount) + for i := 0; i < columnCount; i++ { + columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) + } + + *dst = CopyInResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} + + return nil +} + +func (src *CopyInResponse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('G') + buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes)))) + + buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes)))) + + for _, fc := range src.ColumnFormatCodes { + buf.Write(bigEndian.Uint16(fc)) + } + + return buf.Bytes(), nil +} + +func (src *CopyInResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ColumnFormatCodes []uint16 + }{ + Type: "CopyInResponse", + ColumnFormatCodes: src.ColumnFormatCodes, + }) +} diff --git a/copy_out_response.go b/copy_out_response.go new file mode 100644 index 00000000..e46d9e8f --- /dev/null +++ b/copy_out_response.go @@ -0,0 +1,64 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type CopyOutResponse struct { + OverallFormat byte + ColumnFormatCodes []uint16 +} + +func (*CopyOutResponse) Backend() {} + +func (dst *CopyOutResponse) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 3 { + return &invalidMessageFormatErr{messageType: "CopyOutResponse"} + } + + overallFormat := buf.Next(1)[0] + + columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) + if buf.Len() != columnCount*2 { + return &invalidMessageFormatErr{messageType: "CopyOutResponse"} + } + + columnFormatCodes := make([]uint16, columnCount) + for i := 0; i < columnCount; i++ { + columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) + } + + *dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} + + return nil +} + +func (src *CopyOutResponse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('H') + buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes)))) + + buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes)))) + + for _, fc := range src.ColumnFormatCodes { + buf.Write(bigEndian.Uint16(fc)) + } + + return buf.Bytes(), nil +} + +func (src *CopyOutResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ColumnFormatCodes []uint16 + }{ + Type: "CopyOutResponse", + ColumnFormatCodes: src.ColumnFormatCodes, + }) +} diff --git a/data_row.go b/data_row.go new file mode 100644 index 00000000..c95861b9 --- /dev/null +++ b/data_row.go @@ -0,0 +1,103 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "encoding/json" +) + +type DataRow struct { + Values [][]byte +} + +func (*DataRow) Backend() {} + +func (dst *DataRow) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: "DataRow"} + } + fieldCount := int(binary.BigEndian.Uint16(buf.Next(2))) + + dst.Values = make([][]byte, fieldCount) + + for i := 0; i < fieldCount; i++ { + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "DataRow"} + } + + msgSize := int(int32(binary.BigEndian.Uint32(buf.Next(4)))) + + // null + if msgSize == -1 { + continue + } + + value := make([]byte, msgSize) + _, err := buf.Read(value) + if err != nil { + return err + } + + dst.Values[i] = value + } + + return nil +} + +func (src *DataRow) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('D') + buf.Write(bigEndian.Uint32(0)) + + buf.Write(bigEndian.Uint16(uint16(len(src.Values)))) + + for _, v := range src.Values { + if v == nil { + buf.Write(bigEndian.Int32(-1)) + continue + } + + buf.Write(bigEndian.Int32(int32(len(v)))) + buf.Write(v) + } + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (src *DataRow) MarshalJSON() ([]byte, error) { + formattedValues := make([]map[string]string, len(src.Values)) + for i, v := range src.Values { + if v == nil { + continue + } + + var hasNonPrintable bool + for _, b := range v { + if b < 32 { + hasNonPrintable = true + break + } + } + + if hasNonPrintable { + formattedValues[i] = map[string]string{"binary": hex.EncodeToString(v)} + } else { + formattedValues[i] = map[string]string{"text": string(v)} + } + } + + return json.Marshal(struct { + Type string + Values []map[string]string + }{ + Type: "DataRow", + Values: formattedValues, + }) +} diff --git a/empty_query_response.go b/empty_query_response.go new file mode 100644 index 00000000..de6e6272 --- /dev/null +++ b/empty_query_response.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type EmptyQueryResponse struct{} + +func (*EmptyQueryResponse) Backend() {} + +func (dst *EmptyQueryResponse) UnmarshalBinary(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *EmptyQueryResponse) MarshalBinary() ([]byte, error) { + return []byte{'I', 0, 0, 0, 4}, nil +} + +func (src *EmptyQueryResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "EmptyQueryResponse", + }) +} diff --git a/error_response.go b/error_response.go new file mode 100644 index 00000000..82e408d7 --- /dev/null +++ b/error_response.go @@ -0,0 +1,197 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "strconv" +) + +type ErrorResponse struct { + Severity string + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string + + UnknownFields map[byte]string +} + +func (*ErrorResponse) Backend() {} + +func (dst *ErrorResponse) UnmarshalBinary(src []byte) error { + *dst = ErrorResponse{} + + buf := bytes.NewBuffer(src) + + for { + k, err := buf.ReadByte() + if err != nil { + return err + } + if k == 0 { + break + } + + vb, err := buf.ReadBytes(0) + if err != nil { + return err + } + v := string(vb[:len(vb)-1]) + + switch k { + case 'S': + dst.Severity = v + case 'C': + dst.Code = v + case 'M': + dst.Message = v + case 'D': + dst.Detail = v + case 'H': + dst.Hint = v + case 'P': + s := v + n, _ := strconv.ParseInt(s, 10, 32) + dst.Position = int32(n) + case 'p': + s := v + n, _ := strconv.ParseInt(s, 10, 32) + dst.InternalPosition = int32(n) + case 'q': + dst.InternalQuery = v + case 'W': + dst.Where = v + case 's': + dst.SchemaName = v + case 't': + dst.TableName = v + case 'c': + dst.ColumnName = v + case 'd': + dst.DataTypeName = v + case 'n': + dst.ConstraintName = v + case 'F': + dst.File = v + case 'L': + s := v + n, _ := strconv.ParseInt(s, 10, 32) + dst.Line = int32(n) + case 'R': + dst.Routine = v + + default: + if dst.UnknownFields == nil { + dst.UnknownFields = make(map[byte]string) + } + dst.UnknownFields[k] = v + } + } + + return nil +} + +func (src *ErrorResponse) MarshalBinary() ([]byte, error) { + return src.marshalBinary('E') +} + +func (src *ErrorResponse) marshalBinary(typeByte byte) ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte(typeByte) + buf.Write(bigEndian.Uint32(0)) + + if src.Severity != "" { + buf.WriteString(src.Severity) + buf.WriteByte(0) + } + if src.Code != "" { + buf.WriteString(src.Code) + buf.WriteByte(0) + } + if src.Message != "" { + buf.WriteString(src.Message) + buf.WriteByte(0) + } + if src.Detail != "" { + buf.WriteString(src.Detail) + buf.WriteByte(0) + } + if src.Hint != "" { + buf.WriteString(src.Hint) + buf.WriteByte(0) + } + if src.Position != 0 { + buf.WriteString(strconv.Itoa(int(src.Position))) + buf.WriteByte(0) + } + if src.InternalPosition != 0 { + buf.WriteString(strconv.Itoa(int(src.InternalPosition))) + buf.WriteByte(0) + } + if src.InternalQuery != "" { + buf.WriteString(src.InternalQuery) + buf.WriteByte(0) + } + if src.Where != "" { + buf.WriteString(src.Where) + buf.WriteByte(0) + } + if src.SchemaName != "" { + buf.WriteString(src.SchemaName) + buf.WriteByte(0) + } + if src.TableName != "" { + buf.WriteString(src.TableName) + buf.WriteByte(0) + } + if src.ColumnName != "" { + buf.WriteString(src.ColumnName) + buf.WriteByte(0) + } + if src.DataTypeName != "" { + buf.WriteString(src.DataTypeName) + buf.WriteByte(0) + } + if src.ConstraintName != "" { + buf.WriteString(src.ConstraintName) + buf.WriteByte(0) + } + if src.File != "" { + buf.WriteString(src.File) + buf.WriteByte(0) + } + if src.Line != 0 { + buf.WriteString(strconv.Itoa(int(src.Line))) + buf.WriteByte(0) + } + if src.Routine != "" { + buf.WriteString(src.Routine) + buf.WriteByte(0) + } + + for k, v := range src.UnknownFields { + buf.WriteByte(k) + buf.WriteByte(0) + buf.WriteString(v) + buf.WriteByte(0) + } + buf.WriteByte(0) + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} diff --git a/frontend.go b/frontend.go new file mode 100644 index 00000000..c1dec461 --- /dev/null +++ b/frontend.go @@ -0,0 +1,70 @@ +package pgproto3 + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/jackc/pgx/chunkreader" +) + +type Frontend struct { + cr *chunkreader.ChunkReader + w io.Writer +} + +func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) { + cr := chunkreader.NewChunkReader(r) + return &Frontend{cr: cr, w: w}, nil +} + +func (b *Frontend) Send(msg FrontendMessage) error { + return errors.New("not implemented") +} + +func (b *Frontend) Receive() (BackendMessage, error) { + backendMessages := map[byte]BackendMessage{ + '1': &ParseComplete{}, + '2': &BindComplete{}, + '3': &CloseComplete{}, + 'A': &NotificationResponse{}, + 'C': &CommandComplete{}, + 'd': &CopyData{}, + 'D': &DataRow{}, + 'E': &ErrorResponse{}, + 'G': &CopyInResponse{}, + 'H': &CopyOutResponse{}, + 'I': &EmptyQueryResponse{}, + 'K': &BackendKeyData{}, + 'n': &NoData{}, + 'N': &NoticeResponse{}, + 'R': &Authentication{}, + 'S': &ParameterStatus{}, + 't': &ParameterDescription{}, + 'T': &RowDescription{}, + 'V': &FunctionCallResponse{}, + 'W': &CopyBothResponse{}, + 'Z': &ReadyForQuery{}, + } + + header, err := b.cr.Next(5) + if err != nil { + return nil, err + } + + msgType := header[0] + bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 + + msgBody, err := b.cr.Next(bodyLen) + if err != nil { + return nil, err + } + + if msg, ok := backendMessages[msgType]; ok { + err = msg.UnmarshalBinary(msgBody) + return msg, err + } + + return nil, fmt.Errorf("unknown message type: %c", msgType) +} diff --git a/function_call_response.go b/function_call_response.go new file mode 100644 index 00000000..5c692b36 --- /dev/null +++ b/function_call_response.go @@ -0,0 +1,73 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "encoding/json" +) + +type FunctionCallResponse struct { + Result []byte +} + +func (*FunctionCallResponse) Backend() {} + +func (dst *FunctionCallResponse) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} + } + resultSize := int(binary.BigEndian.Uint32(buf.Next(4))) + if buf.Len() != resultSize { + return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} + } + + dst.Result = make([]byte, resultSize) + copy(dst.Result, buf.Bytes()) + + return nil +} + +func (src *FunctionCallResponse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('V') + buf.Write(bigEndian.Uint32(uint32(4 + 4 + len(src.Result)))) + + if src.Result == nil { + buf.Write(bigEndian.Int32(-1)) + } else { + buf.Write(bigEndian.Int32(int32(len(src.Result)))) + buf.Write(src.Result) + } + + return buf.Bytes(), nil +} + +func (src *FunctionCallResponse) MarshalJSON() ([]byte, error) { + var formattedValue map[string]string + var hasNonPrintable bool + for _, b := range src.Result { + if b < 32 { + hasNonPrintable = true + break + } + } + + if hasNonPrintable { + formattedValue = map[string]string{"binary": hex.EncodeToString(src.Result)} + } else { + formattedValue = map[string]string{"text": string(src.Result)} + } + + return json.Marshal(struct { + Type string + Result map[string]string + }{ + Type: "FunctionCallResponse", + Result: formattedValue, + }) +} diff --git a/no_data.go b/no_data.go new file mode 100644 index 00000000..47ebf28e --- /dev/null +++ b/no_data.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type NoData struct{} + +func (*NoData) Backend() {} + +func (dst *NoData) UnmarshalBinary(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *NoData) MarshalBinary() ([]byte, error) { + return []byte{'n', 0, 0, 0, 4}, nil +} + +func (src *NoData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "NoData", + }) +} diff --git a/notice_response.go b/notice_response.go new file mode 100644 index 00000000..767c9a67 --- /dev/null +++ b/notice_response.go @@ -0,0 +1,13 @@ +package pgproto3 + +type NoticeResponse ErrorResponse + +func (*NoticeResponse) Backend() {} + +func (dst *NoticeResponse) UnmarshalBinary(src []byte) error { + return (*ErrorResponse)(dst).UnmarshalBinary(src) +} + +func (src *NoticeResponse) MarshalBinary() ([]byte, error) { + return (*ErrorResponse)(src).marshalBinary('N') +} diff --git a/notification_response.go b/notification_response.go new file mode 100644 index 00000000..4ae8bab3 --- /dev/null +++ b/notification_response.go @@ -0,0 +1,65 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type NotificationResponse struct { + PID uint32 + Channel string + Payload string +} + +func (*NotificationResponse) Backend() {} + +func (dst *NotificationResponse) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + pid := binary.BigEndian.Uint32(buf.Next(4)) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + channel := string(b[:len(b)-1]) + + b, err = buf.ReadBytes(0) + if err != nil { + return err + } + payload := string(b[:len(b)-1]) + + *dst = NotificationResponse{PID: pid, Channel: channel, Payload: payload} + return nil +} + +func (src *NotificationResponse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('A') + buf.Write(bigEndian.Uint32(uint32(4 + 4 + len(src.Channel) + len(src.Payload)))) + + buf.WriteString(src.Channel) + buf.WriteByte(0) + buf.WriteString(src.Payload) + buf.WriteByte(0) + + return buf.Bytes(), nil +} + +func (src *NotificationResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + PID uint32 + Channel string + Payload string + }{ + Type: "NotificationResponse", + PID: src.PID, + Channel: src.Channel, + Payload: src.Payload, + }) +} diff --git a/parameter_description.go b/parameter_description.go new file mode 100644 index 00000000..40d92c50 --- /dev/null +++ b/parameter_description.go @@ -0,0 +1,60 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type ParameterDescription struct { + ParameterOIDs []uint32 +} + +func (*ParameterDescription) Backend() {} + +func (dst *ParameterDescription) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: "ParameterDescription"} + } + + // Reported parameter count will be incorrect when number of args is greater than uint16 + buf.Next(2) + // Instead infer parameter count by remaining size of message + parameterCount := buf.Len() / 4 + + *dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)} + + for i := 0; i < parameterCount; i++ { + dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4)) + } + + return nil +} + +func (src *ParameterDescription) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('t') + buf.Write(bigEndian.Uint32(uint32(4 + 2 + 4*len(src.ParameterOIDs)))) + + buf.Write(bigEndian.Uint16(uint16(len(src.ParameterOIDs)))) + + for _, oid := range src.ParameterOIDs { + buf.Write(bigEndian.Uint32(oid)) + } + + return buf.Bytes(), nil +} + +func (src *ParameterDescription) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ParameterOIDs []uint32 + }{ + Type: "ParameterDescription", + ParameterOIDs: src.ParameterOIDs, + }) +} diff --git a/parameter_status.go b/parameter_status.go new file mode 100644 index 00000000..b8ce7f8d --- /dev/null +++ b/parameter_status.go @@ -0,0 +1,62 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type ParameterStatus struct { + Name string + Value string +} + +func (*ParameterStatus) Backend() {} + +func (dst *ParameterStatus) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + name := string(b[:len(b)-1]) + + b, err = buf.ReadBytes(0) + if err != nil { + return err + } + value := string(b[:len(b)-1]) + + *dst = ParameterStatus{Name: name, Value: value} + return nil +} + +func (src *ParameterStatus) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('S') + buf.Write(bigEndian.Uint32(0)) + + buf.WriteString(src.Name) + buf.WriteByte(0) + buf.WriteString(src.Value) + buf.WriteByte(0) + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (ps *ParameterStatus) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Name string + Value string + }{ + Type: "ParameterStatus", + Name: ps.Name, + Value: ps.Value, + }) +} diff --git a/parse_complete.go b/parse_complete.go new file mode 100644 index 00000000..24951e3d --- /dev/null +++ b/parse_complete.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type ParseComplete struct{} + +func (*ParseComplete) Backend() {} + +func (dst *ParseComplete) UnmarshalBinary(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *ParseComplete) MarshalBinary() ([]byte, error) { + return []byte{'1', 0, 0, 0, 4}, nil +} + +func (src *ParseComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "ParseComplete", + }) +} diff --git a/pgproto3.go b/pgproto3.go new file mode 100644 index 00000000..a9221239 --- /dev/null +++ b/pgproto3.go @@ -0,0 +1,88 @@ +package pgproto3 + +import "fmt" + +type Message interface { + UnmarshalBinary(data []byte) error + MarshalBinary() (data []byte, err error) +} + +type FrontendMessage interface { + Message + Frontend() // no-op method to distinguish frontend from backend methods +} + +type BackendMessage interface { + Message + Backend() // no-op method to distinguish frontend from backend methods +} + +// func ParseBackend(typeByte byte, body []byte) (BackendMessage, error) { +// switch typeByte { +// case '1': +// return ParseParseComplete(body) +// case '2': +// return ParseBindComplete(body) +// case 'C': +// return ParseCommandComplete(body) +// case 'D': +// return ParseDataRow(body) +// case 'E': +// return ParseErrorResponse(body) +// case 'K': +// return ParseBackendKeyData(body) +// case 'R': +// return ParseAuthentication(body) +// case 'S': +// return ParseParameterStatus(body) +// case 'T': +// return ParseRowDescription(body) +// case 't': +// return ParseParameterDescription(body) +// case 'Z': +// return ParseReadyForQuery(body) +// default: +// return ParseUnknownMessage(typeByte, body) +// } +// } + +// func ParseFrontend(typeByte byte, body []byte) (FrontendMessage, error) { +// switch typeByte { +// case 'B': +// return ParseBind(body) +// case 'D': +// return ParseDescribe(body) +// case 'E': +// return ParseExecute(body) +// case 'P': +// return ParseParse(body) +// case 'p': +// return ParsePasswordMessage(body) +// case 'Q': +// return ParseQuery(body) +// case 'S': +// return ParseSync(body) +// case 'X': +// return ParseTerminate(body) +// default: +// return ParseUnknownMessage(typeByte, body) +// } +// } + +type invalidMessageLenErr struct { + messageType string + expectedLen int + actualLen int +} + +func (e *invalidMessageLenErr) Error() string { + return fmt.Sprintf("%s body must have length of %d, but it is %d", e.messageType, e.expectedLen, e.actualLen) +} + +type invalidMessageFormatErr struct { + messageType string +} + +func (e *invalidMessageFormatErr) Error() string { + return fmt.Sprintf("%s body is invalid", e.messageType) +} diff --git a/query.go b/query.go new file mode 100644 index 00000000..a3fc32eb --- /dev/null +++ b/query.go @@ -0,0 +1,43 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" +) + +type Query struct { + String string +} + +func (*Query) Frontend() {} + +func (dst *Query) UnmarshalBinary(src []byte) error { + i := bytes.IndexByte(src, 0) + if i != len(src)-1 { + return &invalidMessageFormatErr{messageType: "Query"} + } + + dst.String = string(src[:i]) + + return nil +} + +func (src *Query) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + buf.WriteByte('Q') + buf.Write(bigEndian.Uint32(uint32(4 + len(src.String) + 1))) + buf.WriteString(src.String) + buf.WriteByte(0) + return buf.Bytes(), nil +} + +func (src *Query) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + String string + }{ + Type: "Query", + String: src.String, + }) +} diff --git a/ready_for_query.go b/ready_for_query.go new file mode 100644 index 00000000..09005d00 --- /dev/null +++ b/ready_for_query.go @@ -0,0 +1,35 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type ReadyForQuery struct { + TxStatus byte +} + +func (*ReadyForQuery) Backend() {} + +func (dst *ReadyForQuery) UnmarshalBinary(src []byte) error { + if len(src) != 1 { + return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)} + } + + dst.TxStatus = src[0] + + return nil +} + +func (src *ReadyForQuery) MarshalBinary() ([]byte, error) { + return []byte{'Z', 0, 0, 0, 5, src.TxStatus}, nil +} + +func (src *ReadyForQuery) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + TxStatus string + }{ + Type: "ReadyForQuery", + TxStatus: string(src.TxStatus), + }) +} diff --git a/row_description.go b/row_description.go new file mode 100644 index 00000000..294a6aa9 --- /dev/null +++ b/row_description.go @@ -0,0 +1,101 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +const ( + TextFormat = 0 + BinaryFormat = 1 +) + +type FieldDescription struct { + Name string + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier uint32 + Format int16 +} + +type RowDescription struct { + Fields []FieldDescription +} + +func (*RowDescription) Backend() {} + +func (dst *RowDescription) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: "RowDescription"} + } + fieldCount := int(binary.BigEndian.Uint16(buf.Next(2))) + + *dst = RowDescription{Fields: make([]FieldDescription, fieldCount)} + + for i := 0; i < fieldCount; i++ { + var fd FieldDescription + bName, err := buf.ReadBytes(0) + if err != nil { + return err + } + fd.Name = string(bName[:len(bName)-1]) + + // Since buf.Next() doesn't return an error if we hit the end of the buffer + // check Len ahead of time + if buf.Len() < 18 { + return &invalidMessageFormatErr{messageType: "RowDescription"} + } + + fd.TableOID = binary.BigEndian.Uint32(buf.Next(4)) + fd.TableAttributeNumber = binary.BigEndian.Uint16(buf.Next(2)) + fd.DataTypeOID = binary.BigEndian.Uint32(buf.Next(4)) + fd.DataTypeSize = int16(binary.BigEndian.Uint16(buf.Next(2))) + fd.TypeModifier = binary.BigEndian.Uint32(buf.Next(4)) + fd.Format = int16(binary.BigEndian.Uint16(buf.Next(2))) + + dst.Fields[i] = fd + } + + return nil +} + +func (src *RowDescription) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('T') + buf.Write(bigEndian.Uint32(0)) + + buf.Write(bigEndian.Uint16(uint16(len(src.Fields)))) + + for _, fd := range src.Fields { + buf.WriteString(fd.Name) + buf.WriteByte(0) + + buf.Write(bigEndian.Uint32(fd.TableOID)) + buf.Write(bigEndian.Uint16(fd.TableAttributeNumber)) + buf.Write(bigEndian.Uint32(fd.DataTypeOID)) + buf.Write(bigEndian.Uint16(uint16(fd.DataTypeSize))) + buf.Write(bigEndian.Uint32(fd.TypeModifier)) + buf.Write(bigEndian.Uint16(uint16(fd.Format))) + } + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (src *RowDescription) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Fields []FieldDescription + }{ + Type: "RowDescription", + Fields: src.Fields, + }) +}