Merge branch 'pgproto3import' into v5-dev
This commit is contained in:
@@ -0,0 +1,9 @@
|
|||||||
|
language: go
|
||||||
|
|
||||||
|
go:
|
||||||
|
- 1.x
|
||||||
|
- tip
|
||||||
|
|
||||||
|
matrix:
|
||||||
|
allow_failures:
|
||||||
|
- go: tip
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
Copyright (c) 2019 Jack Christensen
|
||||||
|
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
a copy of this software and associated documentation files (the
|
||||||
|
"Software"), to deal in the Software without restriction, including
|
||||||
|
without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be
|
||||||
|
included in all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
[](https://godoc.org/github.com/jackc/pgproto3)
|
||||||
|
[](https://travis-ci.org/jackc/pgproto3)
|
||||||
|
|
||||||
|
# pgproto3
|
||||||
|
|
||||||
|
Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3.
|
||||||
|
|
||||||
|
pgproto3 can be used as a foundation for PostgreSQL drivers, proxies, mock servers, load balancers and more.
|
||||||
|
|
||||||
|
See example/pgfortune for a playful example of a fake PostgreSQL server.
|
||||||
|
|
||||||
|
Extracted from original implementation in https://github.com/jackc/pgx.
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required.
|
||||||
|
type AuthenticationCleartextPassword struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationCleartextPassword) Backend() {}
|
||||||
|
|
||||||
|
// Backend identifies this message as an authentication response.
|
||||||
|
func (*AuthenticationCleartextPassword) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
|
||||||
|
if len(src) != 4 {
|
||||||
|
return errors.New("bad authentication message size")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeCleartextPassword {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
dst = pgio.AppendInt32(dst, 8)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src AuthenticationCleartextPassword) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationCleartextPassword",
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,77 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationMD5Password is a message sent from the backend indicating that an MD5 hashed password is required.
|
||||||
|
type AuthenticationMD5Password struct {
|
||||||
|
Salt [4]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationMD5Password) Backend() {}
|
||||||
|
|
||||||
|
// Backend identifies this message as an authentication response.
|
||||||
|
func (*AuthenticationMD5Password) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationMD5Password) Decode(src []byte) error {
|
||||||
|
if len(src) != 8 {
|
||||||
|
return errors.New("bad authentication message size")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeMD5Password {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(dst.Salt[:], src[4:8])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
dst = pgio.AppendInt32(dst, 12)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
|
||||||
|
dst = append(dst, src.Salt[:]...)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src AuthenticationMD5Password) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Salt [4]byte
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationMD5Password",
|
||||||
|
Salt: src.Salt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Type string
|
||||||
|
Salt [4]byte
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Salt = msg.Salt
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationOk is a message sent from the backend indicating that authentication was successful.
|
||||||
|
type AuthenticationOk struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationOk) Backend() {}
|
||||||
|
|
||||||
|
// Backend identifies this message as an authentication response.
|
||||||
|
func (*AuthenticationOk) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationOk) Decode(src []byte) error {
|
||||||
|
if len(src) != 4 {
|
||||||
|
return errors.New("bad authentication message size")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeOk {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationOk) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
dst = pgio.AppendInt32(dst, 8)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeOk)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src AuthenticationOk) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationOK",
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,75 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationSASL is a message sent from the backend indicating that SASL authentication is required.
|
||||||
|
type AuthenticationSASL struct {
|
||||||
|
AuthMechanisms []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationSASL) Backend() {}
|
||||||
|
|
||||||
|
// Backend identifies this message as an authentication response.
|
||||||
|
func (*AuthenticationSASL) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationSASL) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("authentication message too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeSASL {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
authMechanisms := src[4:]
|
||||||
|
for len(authMechanisms) > 1 {
|
||||||
|
idx := bytes.IndexByte(authMechanisms, 0)
|
||||||
|
if idx > 0 {
|
||||||
|
dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx]))
|
||||||
|
authMechanisms = authMechanisms[idx+1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationSASL) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeSASL)
|
||||||
|
|
||||||
|
for _, s := range src.AuthMechanisms {
|
||||||
|
dst = append(dst, []byte(s)...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
}
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src AuthenticationSASL) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
AuthMechanisms []string
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationSASL",
|
||||||
|
AuthMechanisms: src.AuthMechanisms,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,81 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationSASLContinue is a message sent from the backend containing a SASL challenge.
|
||||||
|
type AuthenticationSASLContinue struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationSASLContinue) Backend() {}
|
||||||
|
|
||||||
|
// Backend identifies this message as an authentication response.
|
||||||
|
func (*AuthenticationSASLContinue) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("authentication message too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeSASLContinue {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Data = src[4:]
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
|
||||||
|
|
||||||
|
dst = append(dst, src.Data...)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src AuthenticationSASLContinue) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Data string
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationSASLContinue",
|
||||||
|
Data: string(src.Data),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *AuthenticationSASLContinue) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Data string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Data = []byte(msg.Data)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,81 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationSASLFinal is a message sent from the backend indicating a SASL authentication has completed.
|
||||||
|
type AuthenticationSASLFinal struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationSASLFinal) Backend() {}
|
||||||
|
|
||||||
|
// Backend identifies this message as an authentication response.
|
||||||
|
func (*AuthenticationSASLFinal) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("authentication message too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeSASLFinal {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Data = src[4:]
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
|
||||||
|
|
||||||
|
dst = append(dst, src.Data...)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (src AuthenticationSASLFinal) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Data string
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationSASLFinal",
|
||||||
|
Data: string(src.Data),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *AuthenticationSASLFinal) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Data string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Data = []byte(msg.Data)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,208 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Backend acts as a server for the PostgreSQL wire protocol version 3.
|
||||||
|
type Backend struct {
|
||||||
|
cr ChunkReader
|
||||||
|
w io.Writer
|
||||||
|
|
||||||
|
// Frontend message flyweights
|
||||||
|
bind Bind
|
||||||
|
cancelRequest CancelRequest
|
||||||
|
_close Close
|
||||||
|
copyFail CopyFail
|
||||||
|
copyData CopyData
|
||||||
|
copyDone CopyDone
|
||||||
|
describe Describe
|
||||||
|
execute Execute
|
||||||
|
flush Flush
|
||||||
|
functionCall FunctionCall
|
||||||
|
gssEncRequest GSSEncRequest
|
||||||
|
parse Parse
|
||||||
|
query Query
|
||||||
|
sslRequest SSLRequest
|
||||||
|
startupMessage StartupMessage
|
||||||
|
sync Sync
|
||||||
|
terminate Terminate
|
||||||
|
|
||||||
|
bodyLen int
|
||||||
|
msgType byte
|
||||||
|
partialMsg bool
|
||||||
|
authType uint32
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code.
|
||||||
|
maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source.
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewBackend creates a new Backend.
|
||||||
|
func NewBackend(cr ChunkReader, w io.Writer) *Backend {
|
||||||
|
return &Backend{cr: cr, w: w}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send sends a message to the frontend.
|
||||||
|
func (b *Backend) Send(msg BackendMessage) error {
|
||||||
|
_, err := b.w.Write(msg.Encode(nil))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method
|
||||||
|
// because the initial connection message is "special" and does not include the message type as the first byte. This
|
||||||
|
// will return either a StartupMessage, SSLRequest, GSSEncRequest, or CancelRequest.
|
||||||
|
func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
|
||||||
|
buf, err := b.cr.Next(4)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
msgSize := int(binary.BigEndian.Uint32(buf) - 4)
|
||||||
|
|
||||||
|
if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen {
|
||||||
|
return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err = b.cr.Next(msgSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
code := binary.BigEndian.Uint32(buf)
|
||||||
|
|
||||||
|
switch code {
|
||||||
|
case ProtocolVersionNumber:
|
||||||
|
err = b.startupMessage.Decode(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &b.startupMessage, nil
|
||||||
|
case sslRequestNumber:
|
||||||
|
err = b.sslRequest.Decode(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &b.sslRequest, nil
|
||||||
|
case cancelRequestCode:
|
||||||
|
err = b.cancelRequest.Decode(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &b.cancelRequest, nil
|
||||||
|
case gssEncReqNumber:
|
||||||
|
err = b.gssEncRequest.Decode(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &b.gssEncRequest, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown startup message code: %d", code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive receives a message from the frontend. The returned message is only valid until the next call to Receive.
|
||||||
|
func (b *Backend) Receive() (FrontendMessage, error) {
|
||||||
|
if !b.partialMsg {
|
||||||
|
header, err := b.cr.Next(5)
|
||||||
|
if err != nil {
|
||||||
|
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.msgType = header[0]
|
||||||
|
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
|
||||||
|
b.partialMsg = true
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg FrontendMessage
|
||||||
|
switch b.msgType {
|
||||||
|
case 'B':
|
||||||
|
msg = &b.bind
|
||||||
|
case 'C':
|
||||||
|
msg = &b._close
|
||||||
|
case 'D':
|
||||||
|
msg = &b.describe
|
||||||
|
case 'E':
|
||||||
|
msg = &b.execute
|
||||||
|
case 'F':
|
||||||
|
msg = &b.functionCall
|
||||||
|
case 'f':
|
||||||
|
msg = &b.copyFail
|
||||||
|
case 'd':
|
||||||
|
msg = &b.copyData
|
||||||
|
case 'c':
|
||||||
|
msg = &b.copyDone
|
||||||
|
case 'H':
|
||||||
|
msg = &b.flush
|
||||||
|
case 'P':
|
||||||
|
msg = &b.parse
|
||||||
|
case 'p':
|
||||||
|
switch b.authType {
|
||||||
|
case AuthTypeSASL:
|
||||||
|
msg = &SASLInitialResponse{}
|
||||||
|
case AuthTypeSASLContinue:
|
||||||
|
msg = &SASLResponse{}
|
||||||
|
case AuthTypeSASLFinal:
|
||||||
|
msg = &SASLResponse{}
|
||||||
|
case AuthTypeCleartextPassword, AuthTypeMD5Password:
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
// to maintain backwards compatability
|
||||||
|
msg = &PasswordMessage{}
|
||||||
|
}
|
||||||
|
case 'Q':
|
||||||
|
msg = &b.query
|
||||||
|
case 'S':
|
||||||
|
msg = &b.sync
|
||||||
|
case 'X':
|
||||||
|
msg = &b.terminate
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown message type: %c", b.msgType)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgBody, err := b.cr.Next(b.bodyLen)
|
||||||
|
if err != nil {
|
||||||
|
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.partialMsg = false
|
||||||
|
|
||||||
|
err = msg.Decode(msgBody)
|
||||||
|
return msg, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAuthType sets the authentication type in the backend.
|
||||||
|
// Since multiple message types can start with 'p', SetAuthType allows
|
||||||
|
// contextual identification of FrontendMessages. For example, in the
|
||||||
|
// PG message flow documentation for PasswordMessage:
|
||||||
|
//
|
||||||
|
// Byte1('p')
|
||||||
|
//
|
||||||
|
// Identifies the message as a password response. Note that this is also used for
|
||||||
|
// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from
|
||||||
|
// the context.
|
||||||
|
//
|
||||||
|
// Since the Frontend does not know about the state of a backend, it is important
|
||||||
|
// to call SetAuthType() after an authentication request is received by the Frontend.
|
||||||
|
func (b *Backend) SetAuthType(authType uint32) error {
|
||||||
|
switch authType {
|
||||||
|
case AuthTypeOk,
|
||||||
|
AuthTypeCleartextPassword,
|
||||||
|
AuthTypeMD5Password,
|
||||||
|
AuthTypeSCMCreds,
|
||||||
|
AuthTypeGSS,
|
||||||
|
AuthTypeGSSCont,
|
||||||
|
AuthTypeSSPI,
|
||||||
|
AuthTypeSASL,
|
||||||
|
AuthTypeSASLContinue,
|
||||||
|
AuthTypeSASLFinal:
|
||||||
|
b.authType = authType
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("authType not recognized: %d", authType)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,51 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BackendKeyData struct {
|
||||||
|
ProcessID uint32
|
||||||
|
SecretKey uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*BackendKeyData) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *BackendKeyData) Decode(src []byte) error {
|
||||||
|
if len(src) != 8 {
|
||||||
|
return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.ProcessID = binary.BigEndian.Uint32(src[:4])
|
||||||
|
dst.SecretKey = binary.BigEndian.Uint32(src[4:])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *BackendKeyData) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'K')
|
||||||
|
dst = pgio.AppendUint32(dst, 12)
|
||||||
|
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||||
|
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src BackendKeyData) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ProcessID uint32
|
||||||
|
SecretKey uint32
|
||||||
|
}{
|
||||||
|
Type: "BackendKeyData",
|
||||||
|
ProcessID: src.ProcessID,
|
||||||
|
SecretKey: src.SecretKey,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,122 @@
|
|||||||
|
package pgproto3_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
"github.com/jackc/pgproto3/v2"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBackendReceiveInterrupted(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
server := &interruptReader{}
|
||||||
|
server.push([]byte{'Q', 0, 0, 0, 6})
|
||||||
|
|
||||||
|
backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil)
|
||||||
|
|
||||||
|
msg, err := backend.Receive()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected err")
|
||||||
|
}
|
||||||
|
if msg != nil {
|
||||||
|
t.Fatalf("did not expect msg, but %v", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
server.push([]byte{'I', 0})
|
||||||
|
|
||||||
|
msg, err = backend.Receive()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if msg, ok := msg.(*pgproto3.Query); !ok || msg.String != "I" {
|
||||||
|
t.Fatalf("unexpected msg: %v", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackendReceiveUnexpectedEOF(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
server := &interruptReader{}
|
||||||
|
server.push([]byte{'Q', 0, 0, 0, 6})
|
||||||
|
|
||||||
|
backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil)
|
||||||
|
|
||||||
|
// Receive regular msg
|
||||||
|
msg, err := backend.Receive()
|
||||||
|
assert.Nil(t, msg)
|
||||||
|
assert.Equal(t, io.ErrUnexpectedEOF, err)
|
||||||
|
|
||||||
|
// Receive StartupMessage msg
|
||||||
|
dst := []byte{}
|
||||||
|
dst = pgio.AppendUint32(dst, 1000) // tell the backend we expect 1000 bytes to be read
|
||||||
|
dst = pgio.AppendUint32(dst, 1) // only send 1 byte
|
||||||
|
server.push(dst)
|
||||||
|
|
||||||
|
msg, err = backend.ReceiveStartupMessage()
|
||||||
|
assert.Nil(t, msg)
|
||||||
|
assert.Equal(t, io.ErrUnexpectedEOF, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartupMessage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("valid StartupMessage", func(t *testing.T) {
|
||||||
|
want := &pgproto3.StartupMessage{
|
||||||
|
ProtocolVersion: pgproto3.ProtocolVersionNumber,
|
||||||
|
Parameters: map[string]string{
|
||||||
|
"username": "tester",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dst := []byte{}
|
||||||
|
dst = want.Encode(dst)
|
||||||
|
|
||||||
|
server := &interruptReader{}
|
||||||
|
server.push(dst)
|
||||||
|
|
||||||
|
backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil)
|
||||||
|
|
||||||
|
msg, err := backend.ReceiveStartupMessage()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, want, msg)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid packet length", func(t *testing.T) {
|
||||||
|
wantErr := "invalid length of startup packet"
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
packetLen uint32
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "large packet length",
|
||||||
|
// Since the StartupMessage contains the "Length of message contents
|
||||||
|
// in bytes, including self", the max startup packet length is actually
|
||||||
|
// 10000+4. Therefore, let's go past the limit with 10005
|
||||||
|
packetLen: 10005,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "short packet length",
|
||||||
|
packetLen: 3,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
server := &interruptReader{}
|
||||||
|
dst := []byte{}
|
||||||
|
dst = pgio.AppendUint32(dst, tt.packetLen)
|
||||||
|
dst = pgio.AppendUint32(dst, pgproto3.ProtocolVersionNumber)
|
||||||
|
server.push(dst)
|
||||||
|
|
||||||
|
backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil)
|
||||||
|
|
||||||
|
msg, err := backend.ReceiveStartupMessage()
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, msg)
|
||||||
|
require.Contains(t, err.Error(), wantErr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BigEndianBuf [8]byte
|
||||||
|
|
||||||
|
func (b BigEndianBuf) Int16(n int16) []byte {
|
||||||
|
buf := b[0:2]
|
||||||
|
binary.BigEndian.PutUint16(buf, uint16(n))
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b BigEndianBuf) Uint16(n uint16) []byte {
|
||||||
|
buf := b[0:2]
|
||||||
|
binary.BigEndian.PutUint16(buf, n)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b BigEndianBuf) Int32(n int32) []byte {
|
||||||
|
buf := b[0:4]
|
||||||
|
binary.BigEndian.PutUint32(buf, uint32(n))
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b BigEndianBuf) Uint32(n uint32) []byte {
|
||||||
|
buf := b[0:4]
|
||||||
|
binary.BigEndian.PutUint32(buf, n)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b BigEndianBuf) Int64(n int64) []byte {
|
||||||
|
buf := b[0:8]
|
||||||
|
binary.BigEndian.PutUint64(buf, uint64(n))
|
||||||
|
return buf
|
||||||
|
}
|
||||||
@@ -0,0 +1,216 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Bind struct {
|
||||||
|
DestinationPortal string
|
||||||
|
PreparedStatement string
|
||||||
|
ParameterFormatCodes []int16
|
||||||
|
Parameters [][]byte
|
||||||
|
ResultFormatCodes []int16
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*Bind) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *Bind) Decode(src []byte) error {
|
||||||
|
*dst = Bind{}
|
||||||
|
|
||||||
|
idx := bytes.IndexByte(src, 0)
|
||||||
|
if idx < 0 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
dst.DestinationPortal = string(src[:idx])
|
||||||
|
rp := idx + 1
|
||||||
|
|
||||||
|
idx = bytes.IndexByte(src[rp:], 0)
|
||||||
|
if idx < 0 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
dst.PreparedStatement = string(src[rp : rp+idx])
|
||||||
|
rp += idx + 1
|
||||||
|
|
||||||
|
if len(src[rp:]) < 2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
|
||||||
|
if parameterFormatCodeCount > 0 {
|
||||||
|
dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount)
|
||||||
|
|
||||||
|
if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
for i := 0; i < parameterFormatCodeCount; i++ {
|
||||||
|
dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src[rp:]) < 2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
parameterCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
|
||||||
|
if parameterCount > 0 {
|
||||||
|
dst.Parameters = make([][]byte, parameterCount)
|
||||||
|
|
||||||
|
for i := 0; i < parameterCount; i++ {
|
||||||
|
if len(src[rp:]) < 4 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
|
||||||
|
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||||
|
rp += 4
|
||||||
|
|
||||||
|
// null
|
||||||
|
if msgSize == -1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src[rp:]) < msgSize {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Parameters[i] = src[rp : rp+msgSize]
|
||||||
|
rp += msgSize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src[rp:]) < 2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
|
||||||
|
dst.ResultFormatCodes = make([]int16, resultFormatCodeCount)
|
||||||
|
if len(src[rp:]) < len(dst.ResultFormatCodes)*2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
for i := 0; i < resultFormatCodeCount; i++ {
|
||||||
|
dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *Bind) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'B')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.DestinationPortal...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
dst = append(dst, src.PreparedStatement...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
|
||||||
|
for _, fc := range src.ParameterFormatCodes {
|
||||||
|
dst = pgio.AppendInt16(dst, fc)
|
||||||
|
}
|
||||||
|
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
|
||||||
|
for _, p := range src.Parameters {
|
||||||
|
if p == nil {
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dst = pgio.AppendInt32(dst, int32(len(p)))
|
||||||
|
dst = append(dst, p...)
|
||||||
|
}
|
||||||
|
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
|
||||||
|
for _, fc := range src.ResultFormatCodes {
|
||||||
|
dst = pgio.AppendInt16(dst, fc)
|
||||||
|
}
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src Bind) MarshalJSON() ([]byte, error) {
|
||||||
|
formattedParameters := make([]map[string]string, len(src.Parameters))
|
||||||
|
for i, p := range src.Parameters {
|
||||||
|
if p == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
textFormat := true
|
||||||
|
if len(src.ParameterFormatCodes) == 1 {
|
||||||
|
textFormat = src.ParameterFormatCodes[0] == 0
|
||||||
|
} else if len(src.ParameterFormatCodes) > 1 {
|
||||||
|
textFormat = src.ParameterFormatCodes[i] == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if textFormat {
|
||||||
|
formattedParameters[i] = map[string]string{"text": string(p)}
|
||||||
|
} else {
|
||||||
|
formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
DestinationPortal string
|
||||||
|
PreparedStatement string
|
||||||
|
ParameterFormatCodes []int16
|
||||||
|
Parameters []map[string]string
|
||||||
|
ResultFormatCodes []int16
|
||||||
|
}{
|
||||||
|
Type: "Bind",
|
||||||
|
DestinationPortal: src.DestinationPortal,
|
||||||
|
PreparedStatement: src.PreparedStatement,
|
||||||
|
ParameterFormatCodes: src.ParameterFormatCodes,
|
||||||
|
Parameters: formattedParameters,
|
||||||
|
ResultFormatCodes: src.ResultFormatCodes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *Bind) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
DestinationPortal string
|
||||||
|
PreparedStatement string
|
||||||
|
ParameterFormatCodes []int16
|
||||||
|
Parameters []map[string]string
|
||||||
|
ResultFormatCodes []int16
|
||||||
|
}
|
||||||
|
err := json.Unmarshal(data, &msg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.DestinationPortal = msg.DestinationPortal
|
||||||
|
dst.PreparedStatement = msg.PreparedStatement
|
||||||
|
dst.ParameterFormatCodes = msg.ParameterFormatCodes
|
||||||
|
dst.Parameters = make([][]byte, len(msg.Parameters))
|
||||||
|
dst.ResultFormatCodes = msg.ResultFormatCodes
|
||||||
|
for n, parameter := range msg.Parameters {
|
||||||
|
dst.Parameters[n], err = getValueFromJSON(parameter)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot get param %d: %w", n, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BindComplete struct{}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*BindComplete) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *BindComplete) Decode(src []byte) error {
|
||||||
|
if len(src) != 0 {
|
||||||
|
return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *BindComplete) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, '2', 0, 0, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src BindComplete) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "BindComplete",
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
const cancelRequestCode = 80877102
|
||||||
|
|
||||||
|
type CancelRequest struct {
|
||||||
|
ProcessID uint32
|
||||||
|
SecretKey uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*CancelRequest) Frontend() {}
|
||||||
|
|
||||||
|
func (dst *CancelRequest) Decode(src []byte) error {
|
||||||
|
if len(src) != 12 {
|
||||||
|
return errors.New("bad cancel request size")
|
||||||
|
}
|
||||||
|
|
||||||
|
requestCode := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if requestCode != cancelRequestCode {
|
||||||
|
return errors.New("bad cancel request code")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.ProcessID = binary.BigEndian.Uint32(src[4:])
|
||||||
|
dst.SecretKey = binary.BigEndian.Uint32(src[8:])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||||
|
func (src *CancelRequest) Encode(dst []byte) []byte {
|
||||||
|
dst = pgio.AppendInt32(dst, 16)
|
||||||
|
dst = pgio.AppendInt32(dst, cancelRequestCode)
|
||||||
|
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||||
|
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CancelRequest) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ProcessID uint32
|
||||||
|
SecretKey uint32
|
||||||
|
}{
|
||||||
|
Type: "CancelRequest",
|
||||||
|
ProcessID: src.ProcessID,
|
||||||
|
SecretKey: src.SecretKey,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/jackc/chunkreader/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChunkReader is an interface to decouple github.com/jackc/chunkreader from this package.
|
||||||
|
type ChunkReader interface {
|
||||||
|
// Next returns buf filled with the next n bytes. If an error (including a partial read) occurs,
|
||||||
|
// buf must be nil. Next must preserve any partially read data. Next must not reuse buf.
|
||||||
|
Next(n int) (buf []byte, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewChunkReader creates and returns a new default ChunkReader.
|
||||||
|
func NewChunkReader(r io.Reader) ChunkReader {
|
||||||
|
return chunkreader.New(r)
|
||||||
|
}
|
||||||
@@ -0,0 +1,89 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Close struct {
|
||||||
|
ObjectType byte // 'S' = prepared statement, 'P' = portal
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*Close) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *Close) Decode(src []byte) error {
|
||||||
|
if len(src) < 2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Close"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.ObjectType = src[0]
|
||||||
|
rp := 1
|
||||||
|
|
||||||
|
idx := bytes.IndexByte(src[rp:], 0)
|
||||||
|
if idx != len(src[rp:])-1 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Close"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Name = string(src[rp : len(src)-1])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *Close) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'C')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.ObjectType)
|
||||||
|
dst = append(dst, src.Name...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src Close) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ObjectType string
|
||||||
|
Name string
|
||||||
|
}{
|
||||||
|
Type: "Close",
|
||||||
|
ObjectType: string(src.ObjectType),
|
||||||
|
Name: src.Name,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *Close) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
ObjectType string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msg.ObjectType) != 1 {
|
||||||
|
return errors.New("invalid length for Close.ObjectType")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.ObjectType = byte(msg.ObjectType[0])
|
||||||
|
dst.Name = msg.Name
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CloseComplete struct{}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*CloseComplete) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *CloseComplete) Decode(src []byte) error {
|
||||||
|
if len(src) != 0 {
|
||||||
|
return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *CloseComplete) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, '3', 0, 0, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CloseComplete) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "CloseComplete",
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CommandComplete struct {
|
||||||
|
CommandTag []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*CommandComplete) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *CommandComplete) Decode(src []byte) error {
|
||||||
|
idx := bytes.IndexByte(src, 0)
|
||||||
|
if idx != len(src)-1 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "CommandComplete"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.CommandTag = src[:idx]
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *CommandComplete) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'C')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.CommandTag...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CommandComplete) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
CommandTag string
|
||||||
|
}{
|
||||||
|
Type: "CommandComplete",
|
||||||
|
CommandTag: string(src.CommandTag),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *CommandComplete) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
CommandTag string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.CommandTag = []byte(msg.CommandTag)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,95 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CopyBothResponse struct {
|
||||||
|
OverallFormat byte
|
||||||
|
ColumnFormatCodes []uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*CopyBothResponse) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *CopyBothResponse) Decode(src []byte) error {
|
||||||
|
buf := bytes.NewBuffer(src)
|
||||||
|
|
||||||
|
if buf.Len() < 3 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
|
||||||
|
}
|
||||||
|
|
||||||
|
overallFormat := buf.Next(1)[0]
|
||||||
|
|
||||||
|
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||||
|
if buf.Len() != columnCount*2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
|
||||||
|
}
|
||||||
|
|
||||||
|
columnFormatCodes := make([]uint16, columnCount)
|
||||||
|
for i := 0; i < columnCount; i++ {
|
||||||
|
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||||
|
}
|
||||||
|
|
||||||
|
*dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *CopyBothResponse) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'W')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||||
|
for _, fc := range src.ColumnFormatCodes {
|
||||||
|
dst = pgio.AppendUint16(dst, fc)
|
||||||
|
}
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CopyBothResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ColumnFormatCodes []uint16
|
||||||
|
}{
|
||||||
|
Type: "CopyBothResponse",
|
||||||
|
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *CopyBothResponse) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
OverallFormat string
|
||||||
|
ColumnFormatCodes []uint16
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msg.OverallFormat) != 1 {
|
||||||
|
return errors.New("invalid length for CopyBothResponse.OverallFormat")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.OverallFormat = msg.OverallFormat[0]
|
||||||
|
dst.ColumnFormatCodes = msg.ColumnFormatCodes
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CopyData struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*CopyData) Backend() {}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*CopyData) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *CopyData) Decode(src []byte) error {
|
||||||
|
dst.Data = src
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *CopyData) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'd')
|
||||||
|
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
|
||||||
|
dst = append(dst, src.Data...)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CopyData) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Data string
|
||||||
|
}{
|
||||||
|
Type: "CopyData",
|
||||||
|
Data: hex.EncodeToString(src.Data),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *CopyData) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Data string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Data = []byte(msg.Data)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CopyDone struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*CopyDone) Backend() {}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*CopyDone) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *CopyDone) Decode(src []byte) error {
|
||||||
|
if len(src) != 0 {
|
||||||
|
return &invalidMessageLenErr{messageType: "CopyDone", expectedLen: 0, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *CopyDone) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, 'c', 0, 0, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CopyDone) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "CopyDone",
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CopyFail struct {
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*CopyFail) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *CopyFail) Decode(src []byte) error {
|
||||||
|
idx := bytes.IndexByte(src, 0)
|
||||||
|
if idx != len(src)-1 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "CopyFail"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Message = string(src[:idx])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *CopyFail) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'f')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.Message...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CopyFail) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Message string
|
||||||
|
}{
|
||||||
|
Type: "CopyFail",
|
||||||
|
Message: src.Message,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,96 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CopyInResponse struct {
|
||||||
|
OverallFormat byte
|
||||||
|
ColumnFormatCodes []uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*CopyInResponse) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *CopyInResponse) Decode(src []byte) error {
|
||||||
|
buf := bytes.NewBuffer(src)
|
||||||
|
|
||||||
|
if buf.Len() < 3 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
|
||||||
|
}
|
||||||
|
|
||||||
|
overallFormat := buf.Next(1)[0]
|
||||||
|
|
||||||
|
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||||
|
if buf.Len() != columnCount*2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
|
||||||
|
}
|
||||||
|
|
||||||
|
columnFormatCodes := make([]uint16, columnCount)
|
||||||
|
for i := 0; i < columnCount; i++ {
|
||||||
|
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||||
|
}
|
||||||
|
|
||||||
|
*dst = CopyInResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *CopyInResponse) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'G')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.OverallFormat)
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||||
|
for _, fc := range src.ColumnFormatCodes {
|
||||||
|
dst = pgio.AppendUint16(dst, fc)
|
||||||
|
}
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CopyInResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ColumnFormatCodes []uint16
|
||||||
|
}{
|
||||||
|
Type: "CopyInResponse",
|
||||||
|
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *CopyInResponse) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
OverallFormat string
|
||||||
|
ColumnFormatCodes []uint16
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msg.OverallFormat) != 1 {
|
||||||
|
return errors.New("invalid length for CopyInResponse.OverallFormat")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.OverallFormat = msg.OverallFormat[0]
|
||||||
|
dst.ColumnFormatCodes = msg.ColumnFormatCodes
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,96 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CopyOutResponse struct {
|
||||||
|
OverallFormat byte
|
||||||
|
ColumnFormatCodes []uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*CopyOutResponse) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *CopyOutResponse) Decode(src []byte) error {
|
||||||
|
buf := bytes.NewBuffer(src)
|
||||||
|
|
||||||
|
if buf.Len() < 3 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
|
||||||
|
}
|
||||||
|
|
||||||
|
overallFormat := buf.Next(1)[0]
|
||||||
|
|
||||||
|
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||||
|
if buf.Len() != columnCount*2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
|
||||||
|
}
|
||||||
|
|
||||||
|
columnFormatCodes := make([]uint16, columnCount)
|
||||||
|
for i := 0; i < columnCount; i++ {
|
||||||
|
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||||
|
}
|
||||||
|
|
||||||
|
*dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *CopyOutResponse) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'H')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.OverallFormat)
|
||||||
|
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||||
|
for _, fc := range src.ColumnFormatCodes {
|
||||||
|
dst = pgio.AppendUint16(dst, fc)
|
||||||
|
}
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CopyOutResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ColumnFormatCodes []uint16
|
||||||
|
}{
|
||||||
|
Type: "CopyOutResponse",
|
||||||
|
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *CopyOutResponse) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
OverallFormat string
|
||||||
|
ColumnFormatCodes []uint16
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msg.OverallFormat) != 1 {
|
||||||
|
return errors.New("invalid length for CopyOutResponse.OverallFormat")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.OverallFormat = msg.OverallFormat[0]
|
||||||
|
dst.ColumnFormatCodes = msg.ColumnFormatCodes
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,142 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DataRow struct {
|
||||||
|
Values [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*DataRow) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *DataRow) Decode(src []byte) error {
|
||||||
|
if len(src) < 2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||||
|
}
|
||||||
|
rp := 0
|
||||||
|
fieldCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
|
||||||
|
// If the capacity of the values slice is too small OR substantially too
|
||||||
|
// large reallocate. This is too avoid one row with many columns from
|
||||||
|
// permanently allocating memory.
|
||||||
|
if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 {
|
||||||
|
newCap := 32
|
||||||
|
if newCap < fieldCount {
|
||||||
|
newCap = fieldCount
|
||||||
|
}
|
||||||
|
dst.Values = make([][]byte, fieldCount, newCap)
|
||||||
|
} else {
|
||||||
|
dst.Values = dst.Values[:fieldCount]
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < fieldCount; i++ {
|
||||||
|
if len(src[rp:]) < 4 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||||
|
}
|
||||||
|
|
||||||
|
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||||
|
rp += 4
|
||||||
|
|
||||||
|
// null
|
||||||
|
if msgSize == -1 {
|
||||||
|
dst.Values[i] = nil
|
||||||
|
} else {
|
||||||
|
if len(src[rp:]) < msgSize {
|
||||||
|
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Values[i] = src[rp : rp+msgSize : rp+msgSize]
|
||||||
|
rp += msgSize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *DataRow) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'D')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
|
||||||
|
for _, v := range src.Values {
|
||||||
|
if v == nil {
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dst = pgio.AppendInt32(dst, int32(len(v)))
|
||||||
|
dst = append(dst, v...)
|
||||||
|
}
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src DataRow) MarshalJSON() ([]byte, error) {
|
||||||
|
formattedValues := make([]map[string]string, len(src.Values))
|
||||||
|
for i, v := range src.Values {
|
||||||
|
if v == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var hasNonPrintable bool
|
||||||
|
for _, b := range v {
|
||||||
|
if b < 32 {
|
||||||
|
hasNonPrintable = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasNonPrintable {
|
||||||
|
formattedValues[i] = map[string]string{"binary": hex.EncodeToString(v)}
|
||||||
|
} else {
|
||||||
|
formattedValues[i] = map[string]string{"text": string(v)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Values []map[string]string
|
||||||
|
}{
|
||||||
|
Type: "DataRow",
|
||||||
|
Values: formattedValues,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *DataRow) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Values []map[string]string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Values = make([][]byte, len(msg.Values))
|
||||||
|
for n, parameter := range msg.Values {
|
||||||
|
var err error
|
||||||
|
dst.Values[n], err = getValueFromJSON(parameter)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Describe struct {
|
||||||
|
ObjectType byte // 'S' = prepared statement, 'P' = portal
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*Describe) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *Describe) Decode(src []byte) error {
|
||||||
|
if len(src) < 2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Describe"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.ObjectType = src[0]
|
||||||
|
rp := 1
|
||||||
|
|
||||||
|
idx := bytes.IndexByte(src[rp:], 0)
|
||||||
|
if idx != len(src[rp:])-1 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Describe"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Name = string(src[rp : len(src)-1])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *Describe) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'D')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.ObjectType)
|
||||||
|
dst = append(dst, src.Name...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src Describe) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ObjectType string
|
||||||
|
Name string
|
||||||
|
}{
|
||||||
|
Type: "Describe",
|
||||||
|
ObjectType: string(src.ObjectType),
|
||||||
|
Name: src.Name,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *Describe) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
ObjectType string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(msg.ObjectType) != 1 {
|
||||||
|
return errors.New("invalid length for Describe.ObjectType")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.ObjectType = byte(msg.ObjectType[0])
|
||||||
|
dst.Name = msg.Name
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,4 @@
|
|||||||
|
// Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3.
|
||||||
|
//
|
||||||
|
// See https://www.postgresql.org/docs/current/protocol-message-formats.html for meanings of the different messages.
|
||||||
|
package pgproto3
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type EmptyQueryResponse struct{}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*EmptyQueryResponse) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *EmptyQueryResponse) Decode(src []byte) error {
|
||||||
|
if len(src) != 0 {
|
||||||
|
return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *EmptyQueryResponse) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, 'I', 0, 0, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src EmptyQueryResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "EmptyQueryResponse",
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,334 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ErrorResponse struct {
|
||||||
|
Severity string
|
||||||
|
SeverityUnlocalized string // only in 9.6 and greater
|
||||||
|
Code string
|
||||||
|
Message string
|
||||||
|
Detail string
|
||||||
|
Hint string
|
||||||
|
Position int32
|
||||||
|
InternalPosition int32
|
||||||
|
InternalQuery string
|
||||||
|
Where string
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
ColumnName string
|
||||||
|
DataTypeName string
|
||||||
|
ConstraintName string
|
||||||
|
File string
|
||||||
|
Line int32
|
||||||
|
Routine string
|
||||||
|
|
||||||
|
UnknownFields map[byte]string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*ErrorResponse) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *ErrorResponse) Decode(src []byte) error {
|
||||||
|
*dst = ErrorResponse{}
|
||||||
|
|
||||||
|
buf := bytes.NewBuffer(src)
|
||||||
|
|
||||||
|
for {
|
||||||
|
k, err := buf.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if k == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
vb, err := buf.ReadBytes(0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
v := string(vb[:len(vb)-1])
|
||||||
|
|
||||||
|
switch k {
|
||||||
|
case 'S':
|
||||||
|
dst.Severity = v
|
||||||
|
case 'V':
|
||||||
|
dst.SeverityUnlocalized = v
|
||||||
|
case 'C':
|
||||||
|
dst.Code = v
|
||||||
|
case 'M':
|
||||||
|
dst.Message = v
|
||||||
|
case 'D':
|
||||||
|
dst.Detail = v
|
||||||
|
case 'H':
|
||||||
|
dst.Hint = v
|
||||||
|
case 'P':
|
||||||
|
s := v
|
||||||
|
n, _ := strconv.ParseInt(s, 10, 32)
|
||||||
|
dst.Position = int32(n)
|
||||||
|
case 'p':
|
||||||
|
s := v
|
||||||
|
n, _ := strconv.ParseInt(s, 10, 32)
|
||||||
|
dst.InternalPosition = int32(n)
|
||||||
|
case 'q':
|
||||||
|
dst.InternalQuery = v
|
||||||
|
case 'W':
|
||||||
|
dst.Where = v
|
||||||
|
case 's':
|
||||||
|
dst.SchemaName = v
|
||||||
|
case 't':
|
||||||
|
dst.TableName = v
|
||||||
|
case 'c':
|
||||||
|
dst.ColumnName = v
|
||||||
|
case 'd':
|
||||||
|
dst.DataTypeName = v
|
||||||
|
case 'n':
|
||||||
|
dst.ConstraintName = v
|
||||||
|
case 'F':
|
||||||
|
dst.File = v
|
||||||
|
case 'L':
|
||||||
|
s := v
|
||||||
|
n, _ := strconv.ParseInt(s, 10, 32)
|
||||||
|
dst.Line = int32(n)
|
||||||
|
case 'R':
|
||||||
|
dst.Routine = v
|
||||||
|
|
||||||
|
default:
|
||||||
|
if dst.UnknownFields == nil {
|
||||||
|
dst.UnknownFields = make(map[byte]string)
|
||||||
|
}
|
||||||
|
dst.UnknownFields[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *ErrorResponse) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, src.marshalBinary('E')...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
|
||||||
|
var bigEndian BigEndianBuf
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
|
||||||
|
buf.WriteByte(typeByte)
|
||||||
|
buf.Write(bigEndian.Uint32(0))
|
||||||
|
|
||||||
|
if src.Severity != "" {
|
||||||
|
buf.WriteByte('S')
|
||||||
|
buf.WriteString(src.Severity)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.SeverityUnlocalized != "" {
|
||||||
|
buf.WriteByte('V')
|
||||||
|
buf.WriteString(src.SeverityUnlocalized)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.Code != "" {
|
||||||
|
buf.WriteByte('C')
|
||||||
|
buf.WriteString(src.Code)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.Message != "" {
|
||||||
|
buf.WriteByte('M')
|
||||||
|
buf.WriteString(src.Message)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.Detail != "" {
|
||||||
|
buf.WriteByte('D')
|
||||||
|
buf.WriteString(src.Detail)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.Hint != "" {
|
||||||
|
buf.WriteByte('H')
|
||||||
|
buf.WriteString(src.Hint)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.Position != 0 {
|
||||||
|
buf.WriteByte('P')
|
||||||
|
buf.WriteString(strconv.Itoa(int(src.Position)))
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.InternalPosition != 0 {
|
||||||
|
buf.WriteByte('p')
|
||||||
|
buf.WriteString(strconv.Itoa(int(src.InternalPosition)))
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.InternalQuery != "" {
|
||||||
|
buf.WriteByte('q')
|
||||||
|
buf.WriteString(src.InternalQuery)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.Where != "" {
|
||||||
|
buf.WriteByte('W')
|
||||||
|
buf.WriteString(src.Where)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.SchemaName != "" {
|
||||||
|
buf.WriteByte('s')
|
||||||
|
buf.WriteString(src.SchemaName)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.TableName != "" {
|
||||||
|
buf.WriteByte('t')
|
||||||
|
buf.WriteString(src.TableName)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.ColumnName != "" {
|
||||||
|
buf.WriteByte('c')
|
||||||
|
buf.WriteString(src.ColumnName)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.DataTypeName != "" {
|
||||||
|
buf.WriteByte('d')
|
||||||
|
buf.WriteString(src.DataTypeName)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.ConstraintName != "" {
|
||||||
|
buf.WriteByte('n')
|
||||||
|
buf.WriteString(src.ConstraintName)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.File != "" {
|
||||||
|
buf.WriteByte('F')
|
||||||
|
buf.WriteString(src.File)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.Line != 0 {
|
||||||
|
buf.WriteByte('L')
|
||||||
|
buf.WriteString(strconv.Itoa(int(src.Line)))
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.Routine != "" {
|
||||||
|
buf.WriteByte('R')
|
||||||
|
buf.WriteString(src.Routine)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range src.UnknownFields {
|
||||||
|
buf.WriteByte(k)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
buf.WriteString(v)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.WriteByte(0)
|
||||||
|
|
||||||
|
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
|
||||||
|
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src ErrorResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Severity string
|
||||||
|
SeverityUnlocalized string // only in 9.6 and greater
|
||||||
|
Code string
|
||||||
|
Message string
|
||||||
|
Detail string
|
||||||
|
Hint string
|
||||||
|
Position int32
|
||||||
|
InternalPosition int32
|
||||||
|
InternalQuery string
|
||||||
|
Where string
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
ColumnName string
|
||||||
|
DataTypeName string
|
||||||
|
ConstraintName string
|
||||||
|
File string
|
||||||
|
Line int32
|
||||||
|
Routine string
|
||||||
|
|
||||||
|
UnknownFields map[byte]string
|
||||||
|
}{
|
||||||
|
Type: "ErrorResponse",
|
||||||
|
Severity: src.Severity,
|
||||||
|
SeverityUnlocalized: src.SeverityUnlocalized,
|
||||||
|
Code: src.Code,
|
||||||
|
Message: src.Message,
|
||||||
|
Detail: src.Detail,
|
||||||
|
Hint: src.Hint,
|
||||||
|
Position: src.Position,
|
||||||
|
InternalPosition: src.InternalPosition,
|
||||||
|
InternalQuery: src.InternalQuery,
|
||||||
|
Where: src.Where,
|
||||||
|
SchemaName: src.SchemaName,
|
||||||
|
TableName: src.TableName,
|
||||||
|
ColumnName: src.ColumnName,
|
||||||
|
DataTypeName: src.DataTypeName,
|
||||||
|
ConstraintName: src.ConstraintName,
|
||||||
|
File: src.File,
|
||||||
|
Line: src.Line,
|
||||||
|
Routine: src.Routine,
|
||||||
|
UnknownFields: src.UnknownFields,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *ErrorResponse) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Type string
|
||||||
|
Severity string
|
||||||
|
SeverityUnlocalized string // only in 9.6 and greater
|
||||||
|
Code string
|
||||||
|
Message string
|
||||||
|
Detail string
|
||||||
|
Hint string
|
||||||
|
Position int32
|
||||||
|
InternalPosition int32
|
||||||
|
InternalQuery string
|
||||||
|
Where string
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
ColumnName string
|
||||||
|
DataTypeName string
|
||||||
|
ConstraintName string
|
||||||
|
File string
|
||||||
|
Line int32
|
||||||
|
Routine string
|
||||||
|
|
||||||
|
UnknownFields map[byte]string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Severity = msg.Severity
|
||||||
|
dst.SeverityUnlocalized = msg.SeverityUnlocalized
|
||||||
|
dst.Code = msg.Code
|
||||||
|
dst.Message = msg.Message
|
||||||
|
dst.Detail = msg.Detail
|
||||||
|
dst.Hint = msg.Hint
|
||||||
|
dst.Position = msg.Position
|
||||||
|
dst.InternalPosition = msg.InternalPosition
|
||||||
|
dst.InternalQuery = msg.InternalQuery
|
||||||
|
dst.Where = msg.Where
|
||||||
|
dst.SchemaName = msg.SchemaName
|
||||||
|
dst.TableName = msg.TableName
|
||||||
|
dst.ColumnName = msg.ColumnName
|
||||||
|
dst.DataTypeName = msg.DataTypeName
|
||||||
|
dst.ConstraintName = msg.ConstraintName
|
||||||
|
dst.File = msg.File
|
||||||
|
dst.Line = msg.Line
|
||||||
|
dst.Routine = msg.Routine
|
||||||
|
|
||||||
|
dst.UnknownFields = msg.UnknownFields
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
# pgfortune
|
||||||
|
|
||||||
|
pgfortune is a mock PostgreSQL server that responds to every query with a fortune.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
Install `fortune` and `cowsay`. They should be available in any Unix package manager (apt, yum, brew, etc.)
|
||||||
|
|
||||||
|
```
|
||||||
|
go get -u github.com/jackc/pgproto3/example/pgfortune
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```
|
||||||
|
$ pgfortune
|
||||||
|
```
|
||||||
|
|
||||||
|
By default pgfortune listens on 127.0.0.1:15432 and responds to queries with `fortune | cowsay -f elephant`. These are
|
||||||
|
configurable with the `listen` and `response-command` arguments respectively.
|
||||||
|
|
||||||
|
While `pgfortune` is running connect to it with `psql`.
|
||||||
|
|
||||||
|
```
|
||||||
|
$ psql -h 127.0.0.1 -p 15432
|
||||||
|
Timing is on.
|
||||||
|
Null display is "∅".
|
||||||
|
Line style is unicode.
|
||||||
|
psql (11.5, server 0.0.0)
|
||||||
|
Type "help" for help.
|
||||||
|
|
||||||
|
jack@127.0.0.1:15432 jack=# select foo;
|
||||||
|
fortune
|
||||||
|
─────────────────────────────────────────────
|
||||||
|
_________________________________________ ↵
|
||||||
|
/ Ships are safe in harbor, but they were \↵
|
||||||
|
\ never meant to stay there. /↵
|
||||||
|
----------------------------------------- ↵
|
||||||
|
\ /\ ___ /\ ↵
|
||||||
|
\ // \/ \/ \\ ↵
|
||||||
|
(( O O )) ↵
|
||||||
|
\\ / \ // ↵
|
||||||
|
\/ | | \/ ↵
|
||||||
|
| | | | ↵
|
||||||
|
| | | | ↵
|
||||||
|
| o | ↵
|
||||||
|
| | | | ↵
|
||||||
|
|m| |m| ↵
|
||||||
|
|
||||||
|
(1 row)
|
||||||
|
|
||||||
|
Time: 28.161 ms
|
||||||
|
```
|
||||||
@@ -0,0 +1,51 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
)
|
||||||
|
|
||||||
|
var options struct {
|
||||||
|
listenAddress string
|
||||||
|
responseCommand string
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.Usage = func() {
|
||||||
|
fmt.Fprintf(os.Stderr, "usage: %s [options]\n", os.Args[0])
|
||||||
|
flag.PrintDefaults()
|
||||||
|
}
|
||||||
|
|
||||||
|
flag.StringVar(&options.listenAddress, "listen", "127.0.0.1:15432", "Listen address")
|
||||||
|
flag.StringVar(&options.responseCommand, "response-command", "fortune | cowsay -f elephant", "Command to execute to generate query response")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", options.listenAddress)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
log.Println("Listening on", ln.Addr())
|
||||||
|
|
||||||
|
for {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
log.Println("Accepted connection from", conn.RemoteAddr())
|
||||||
|
|
||||||
|
b := NewPgFortuneBackend(conn, func() ([]byte, error) {
|
||||||
|
return exec.Command("sh", "-c", options.responseCommand).CombinedOutput()
|
||||||
|
})
|
||||||
|
go func() {
|
||||||
|
err := b.Run()
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
}
|
||||||
|
log.Println("Closed connection from", conn.RemoteAddr())
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,104 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/jackc/pgproto3/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PgFortuneBackend struct {
|
||||||
|
backend *pgproto3.Backend
|
||||||
|
conn net.Conn
|
||||||
|
responder func() ([]byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPgFortuneBackend(conn net.Conn, responder func() ([]byte, error)) *PgFortuneBackend {
|
||||||
|
backend := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)
|
||||||
|
|
||||||
|
connHandler := &PgFortuneBackend{
|
||||||
|
backend: backend,
|
||||||
|
conn: conn,
|
||||||
|
responder: responder,
|
||||||
|
}
|
||||||
|
|
||||||
|
return connHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PgFortuneBackend) Run() error {
|
||||||
|
defer p.Close()
|
||||||
|
|
||||||
|
err := p.handleStartup()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
msg, err := p.backend.Receive()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error receiving message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg.(type) {
|
||||||
|
case *pgproto3.Query:
|
||||||
|
response, err := p.responder()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error generating query response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := (&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
|
||||||
|
{
|
||||||
|
Name: []byte("fortune"),
|
||||||
|
TableOID: 0,
|
||||||
|
TableAttributeNumber: 0,
|
||||||
|
DataTypeOID: 25,
|
||||||
|
DataTypeSize: -1,
|
||||||
|
TypeModifier: -1,
|
||||||
|
Format: 0,
|
||||||
|
},
|
||||||
|
}}).Encode(nil)
|
||||||
|
buf = (&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf)
|
||||||
|
buf = (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf)
|
||||||
|
buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
|
||||||
|
_, err = p.conn.Write(buf)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error writing query response: %w", err)
|
||||||
|
}
|
||||||
|
case *pgproto3.Terminate:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("received message other than Query from client: %#v", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PgFortuneBackend) handleStartup() error {
|
||||||
|
startupMessage, err := p.backend.ReceiveStartupMessage()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error receiving startup message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch startupMessage.(type) {
|
||||||
|
case *pgproto3.StartupMessage:
|
||||||
|
buf := (&pgproto3.AuthenticationOk{}).Encode(nil)
|
||||||
|
buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
|
||||||
|
_, err = p.conn.Write(buf)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error sending ready for query: %w", err)
|
||||||
|
}
|
||||||
|
case *pgproto3.SSLRequest:
|
||||||
|
_, err = p.conn.Write([]byte("N"))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error sending deny SSL request: %w", err)
|
||||||
|
}
|
||||||
|
return p.handleStartup()
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown startup message: %#v", startupMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PgFortuneBackend) Close() error {
|
||||||
|
return p.conn.Close()
|
||||||
|
}
|
||||||
@@ -0,0 +1,65 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Execute struct {
|
||||||
|
Portal string
|
||||||
|
MaxRows uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*Execute) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *Execute) Decode(src []byte) error {
|
||||||
|
buf := bytes.NewBuffer(src)
|
||||||
|
|
||||||
|
b, err := buf.ReadBytes(0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.Portal = string(b[:len(b)-1])
|
||||||
|
|
||||||
|
if buf.Len() < 4 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Execute"}
|
||||||
|
}
|
||||||
|
dst.MaxRows = binary.BigEndian.Uint32(buf.Next(4))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *Execute) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'E')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.Portal...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
dst = pgio.AppendUint32(dst, src.MaxRows)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src Execute) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Portal string
|
||||||
|
MaxRows uint32
|
||||||
|
}{
|
||||||
|
Type: "Execute",
|
||||||
|
Portal: src.Portal,
|
||||||
|
MaxRows: src.MaxRows,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Flush struct{}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*Flush) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *Flush) Decode(src []byte) error {
|
||||||
|
if len(src) != 0 {
|
||||||
|
return &invalidMessageLenErr{messageType: "Flush", expectedLen: 0, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *Flush) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, 'H', 0, 0, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src Flush) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "Flush",
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,201 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Frontend acts as a client for the PostgreSQL wire protocol version 3.
|
||||||
|
type Frontend struct {
|
||||||
|
cr ChunkReader
|
||||||
|
w io.Writer
|
||||||
|
|
||||||
|
// Backend message flyweights
|
||||||
|
authenticationOk AuthenticationOk
|
||||||
|
authenticationCleartextPassword AuthenticationCleartextPassword
|
||||||
|
authenticationMD5Password AuthenticationMD5Password
|
||||||
|
authenticationSASL AuthenticationSASL
|
||||||
|
authenticationSASLContinue AuthenticationSASLContinue
|
||||||
|
authenticationSASLFinal AuthenticationSASLFinal
|
||||||
|
backendKeyData BackendKeyData
|
||||||
|
bindComplete BindComplete
|
||||||
|
closeComplete CloseComplete
|
||||||
|
commandComplete CommandComplete
|
||||||
|
copyBothResponse CopyBothResponse
|
||||||
|
copyData CopyData
|
||||||
|
copyInResponse CopyInResponse
|
||||||
|
copyOutResponse CopyOutResponse
|
||||||
|
copyDone CopyDone
|
||||||
|
dataRow DataRow
|
||||||
|
emptyQueryResponse EmptyQueryResponse
|
||||||
|
errorResponse ErrorResponse
|
||||||
|
functionCallResponse FunctionCallResponse
|
||||||
|
noData NoData
|
||||||
|
noticeResponse NoticeResponse
|
||||||
|
notificationResponse NotificationResponse
|
||||||
|
parameterDescription ParameterDescription
|
||||||
|
parameterStatus ParameterStatus
|
||||||
|
parseComplete ParseComplete
|
||||||
|
readyForQuery ReadyForQuery
|
||||||
|
rowDescription RowDescription
|
||||||
|
portalSuspended PortalSuspended
|
||||||
|
|
||||||
|
bodyLen int
|
||||||
|
msgType byte
|
||||||
|
partialMsg bool
|
||||||
|
authType uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFrontend creates a new Frontend.
|
||||||
|
func NewFrontend(cr ChunkReader, w io.Writer) *Frontend {
|
||||||
|
return &Frontend{cr: cr, w: w}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send sends a message to the backend.
|
||||||
|
func (f *Frontend) Send(msg FrontendMessage) error {
|
||||||
|
_, err := f.w.Write(msg.Encode(nil))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func translateEOFtoErrUnexpectedEOF(err error) error {
|
||||||
|
if err == io.EOF {
|
||||||
|
return io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive receives a message from the backend. The returned message is only valid until the next call to Receive.
|
||||||
|
func (f *Frontend) Receive() (BackendMessage, error) {
|
||||||
|
if !f.partialMsg {
|
||||||
|
header, err := f.cr.Next(5)
|
||||||
|
if err != nil {
|
||||||
|
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f.msgType = header[0]
|
||||||
|
f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
|
||||||
|
f.partialMsg = true
|
||||||
|
}
|
||||||
|
|
||||||
|
msgBody, err := f.cr.Next(f.bodyLen)
|
||||||
|
if err != nil {
|
||||||
|
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f.partialMsg = false
|
||||||
|
|
||||||
|
var msg BackendMessage
|
||||||
|
switch f.msgType {
|
||||||
|
case '1':
|
||||||
|
msg = &f.parseComplete
|
||||||
|
case '2':
|
||||||
|
msg = &f.bindComplete
|
||||||
|
case '3':
|
||||||
|
msg = &f.closeComplete
|
||||||
|
case 'A':
|
||||||
|
msg = &f.notificationResponse
|
||||||
|
case 'c':
|
||||||
|
msg = &f.copyDone
|
||||||
|
case 'C':
|
||||||
|
msg = &f.commandComplete
|
||||||
|
case 'd':
|
||||||
|
msg = &f.copyData
|
||||||
|
case 'D':
|
||||||
|
msg = &f.dataRow
|
||||||
|
case 'E':
|
||||||
|
msg = &f.errorResponse
|
||||||
|
case 'G':
|
||||||
|
msg = &f.copyInResponse
|
||||||
|
case 'H':
|
||||||
|
msg = &f.copyOutResponse
|
||||||
|
case 'I':
|
||||||
|
msg = &f.emptyQueryResponse
|
||||||
|
case 'K':
|
||||||
|
msg = &f.backendKeyData
|
||||||
|
case 'n':
|
||||||
|
msg = &f.noData
|
||||||
|
case 'N':
|
||||||
|
msg = &f.noticeResponse
|
||||||
|
case 'R':
|
||||||
|
var err error
|
||||||
|
msg, err = f.findAuthenticationMessageType(msgBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
case 's':
|
||||||
|
msg = &f.portalSuspended
|
||||||
|
case 'S':
|
||||||
|
msg = &f.parameterStatus
|
||||||
|
case 't':
|
||||||
|
msg = &f.parameterDescription
|
||||||
|
case 'T':
|
||||||
|
msg = &f.rowDescription
|
||||||
|
case 'V':
|
||||||
|
msg = &f.functionCallResponse
|
||||||
|
case 'W':
|
||||||
|
msg = &f.copyBothResponse
|
||||||
|
case 'Z':
|
||||||
|
msg = &f.readyForQuery
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown message type: %c", f.msgType)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = msg.Decode(msgBody)
|
||||||
|
return msg, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authentication message type constants.
|
||||||
|
// See src/include/libpq/pqcomm.h for all
|
||||||
|
// constants.
|
||||||
|
const (
|
||||||
|
AuthTypeOk = 0
|
||||||
|
AuthTypeCleartextPassword = 3
|
||||||
|
AuthTypeMD5Password = 5
|
||||||
|
AuthTypeSCMCreds = 6
|
||||||
|
AuthTypeGSS = 7
|
||||||
|
AuthTypeGSSCont = 8
|
||||||
|
AuthTypeSSPI = 9
|
||||||
|
AuthTypeSASL = 10
|
||||||
|
AuthTypeSASLContinue = 11
|
||||||
|
AuthTypeSASLFinal = 12
|
||||||
|
)
|
||||||
|
|
||||||
|
func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return nil, errors.New("authentication message too short")
|
||||||
|
}
|
||||||
|
f.authType = binary.BigEndian.Uint32(src[:4])
|
||||||
|
|
||||||
|
switch f.authType {
|
||||||
|
case AuthTypeOk:
|
||||||
|
return &f.authenticationOk, nil
|
||||||
|
case AuthTypeCleartextPassword:
|
||||||
|
return &f.authenticationCleartextPassword, nil
|
||||||
|
case AuthTypeMD5Password:
|
||||||
|
return &f.authenticationMD5Password, nil
|
||||||
|
case AuthTypeSCMCreds:
|
||||||
|
return nil, errors.New("AuthTypeSCMCreds is unimplemented")
|
||||||
|
case AuthTypeGSS:
|
||||||
|
return nil, errors.New("AuthTypeGSS is unimplemented")
|
||||||
|
case AuthTypeGSSCont:
|
||||||
|
return nil, errors.New("AuthTypeGSSCont is unimplemented")
|
||||||
|
case AuthTypeSSPI:
|
||||||
|
return nil, errors.New("AuthTypeSSPI is unimplemented")
|
||||||
|
case AuthTypeSASL:
|
||||||
|
return &f.authenticationSASL, nil
|
||||||
|
case AuthTypeSASLContinue:
|
||||||
|
return &f.authenticationSASLContinue, nil
|
||||||
|
case AuthTypeSASLFinal:
|
||||||
|
return &f.authenticationSASLFinal, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown authentication type: %d", f.authType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuthType returns the authType used in the current state of the frontend.
|
||||||
|
// See SetAuthType for more information.
|
||||||
|
func (f *Frontend) GetAuthType() uint32 {
|
||||||
|
return f.authType
|
||||||
|
}
|
||||||
@@ -0,0 +1,117 @@
|
|||||||
|
package pgproto3_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgproto3/v2"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type interruptReader struct {
|
||||||
|
chunks [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ir *interruptReader) Read(p []byte) (n int, err error) {
|
||||||
|
if len(ir.chunks) == 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
n = copy(p, ir.chunks[0])
|
||||||
|
if n != len(ir.chunks[0]) {
|
||||||
|
panic("this test reader doesn't support partial reads of chunks")
|
||||||
|
}
|
||||||
|
|
||||||
|
ir.chunks = ir.chunks[1:]
|
||||||
|
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ir *interruptReader) push(p []byte) {
|
||||||
|
ir.chunks = append(ir.chunks, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFrontendReceiveInterrupted(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
server := &interruptReader{}
|
||||||
|
server.push([]byte{'Z', 0, 0, 0, 5})
|
||||||
|
|
||||||
|
frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(server), nil)
|
||||||
|
|
||||||
|
msg, err := frontend.Receive()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected err")
|
||||||
|
}
|
||||||
|
if msg != nil {
|
||||||
|
t.Fatalf("did not expect msg, but %v", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
server.push([]byte{'I'})
|
||||||
|
|
||||||
|
msg, err = frontend.Receive()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if msg, ok := msg.(*pgproto3.ReadyForQuery); !ok || msg.TxStatus != 'I' {
|
||||||
|
t.Fatalf("unexpected msg: %v", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFrontendReceiveUnexpectedEOF(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
server := &interruptReader{}
|
||||||
|
server.push([]byte{'Z', 0, 0, 0, 5})
|
||||||
|
|
||||||
|
frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(server), nil)
|
||||||
|
|
||||||
|
msg, err := frontend.Receive()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected err")
|
||||||
|
}
|
||||||
|
if msg != nil {
|
||||||
|
t.Fatalf("did not expect msg, but %v", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err = frontend.Receive()
|
||||||
|
assert.Nil(t, msg)
|
||||||
|
assert.Equal(t, io.ErrUnexpectedEOF, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestErrorResponse(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
want := &pgproto3.ErrorResponse{
|
||||||
|
Severity: "ERROR",
|
||||||
|
SeverityUnlocalized: "ERROR",
|
||||||
|
Message: `column "foo" does not exist`,
|
||||||
|
File: "parse_relation.c",
|
||||||
|
Code: "42703",
|
||||||
|
Position: 8,
|
||||||
|
Line: 3513,
|
||||||
|
Routine: "errorMissingColumn",
|
||||||
|
}
|
||||||
|
|
||||||
|
raw := []byte{
|
||||||
|
'E', 0, 0, 0, 'f',
|
||||||
|
'S', 'E', 'R', 'R', 'O', 'R', 0,
|
||||||
|
'V', 'E', 'R', 'R', 'O', 'R', 0,
|
||||||
|
'C', '4', '2', '7', '0', '3', 0,
|
||||||
|
'M', 'c', 'o', 'l', 'u', 'm', 'n', 32, '"', 'f', 'o', 'o', '"', 32, 'd', 'o', 'e', 's', 32, 'n', 'o', 't', 32, 'e', 'x', 'i', 's', 't', 0,
|
||||||
|
'P', '8', 0,
|
||||||
|
'F', 'p', 'a', 'r', 's', 'e', '_', 'r', 'e', 'l', 'a', 't', 'i', 'o', 'n', '.', 'c', 0,
|
||||||
|
'L', '3', '5', '1', '3', 0,
|
||||||
|
'R', 'e', 'r', 'r', 'o', 'r', 'M', 'i', 's', 's', 'i', 'n', 'g', 'C', 'o', 'l', 'u', 'm', 'n', 0, 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
server := &interruptReader{}
|
||||||
|
server.push(raw)
|
||||||
|
|
||||||
|
frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(server), nil)
|
||||||
|
|
||||||
|
got, err := frontend.Receive()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
}
|
||||||
@@ -0,0 +1,94 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type FunctionCall struct {
|
||||||
|
Function uint32
|
||||||
|
ArgFormatCodes []uint16
|
||||||
|
Arguments [][]byte
|
||||||
|
ResultFormatCode uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*FunctionCall) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *FunctionCall) Decode(src []byte) error {
|
||||||
|
*dst = FunctionCall{}
|
||||||
|
rp := 0
|
||||||
|
// Specifies the object ID of the function to call.
|
||||||
|
dst.Function = binary.BigEndian.Uint32(src[rp:])
|
||||||
|
rp += 4
|
||||||
|
// The number of argument format codes that follow (denoted C below).
|
||||||
|
// This can be zero to indicate that there are no arguments or that the arguments all use the default format (text);
|
||||||
|
// or one, in which case the specified format code is applied to all arguments;
|
||||||
|
// or it can equal the actual number of arguments.
|
||||||
|
nArgumentCodes := int(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
argumentCodes := make([]uint16, nArgumentCodes)
|
||||||
|
for i := 0; i < nArgumentCodes; i++ {
|
||||||
|
// The argument format codes. Each must presently be zero (text) or one (binary).
|
||||||
|
ac := binary.BigEndian.Uint16(src[rp:])
|
||||||
|
if ac != 0 && ac != 1 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "FunctionCall"}
|
||||||
|
}
|
||||||
|
argumentCodes[i] = ac
|
||||||
|
rp += 2
|
||||||
|
}
|
||||||
|
dst.ArgFormatCodes = argumentCodes
|
||||||
|
|
||||||
|
// Specifies the number of arguments being supplied to the function.
|
||||||
|
nArguments := int(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
arguments := make([][]byte, nArguments)
|
||||||
|
for i := 0; i < nArguments; i++ {
|
||||||
|
// The length of the argument value, in bytes (this count does not include itself). Can be zero.
|
||||||
|
// As a special case, -1 indicates a NULL argument value. No value bytes follow in the NULL case.
|
||||||
|
argumentLength := int(binary.BigEndian.Uint32(src[rp:]))
|
||||||
|
rp += 4
|
||||||
|
if argumentLength == -1 {
|
||||||
|
arguments[i] = nil
|
||||||
|
} else {
|
||||||
|
// The value of the argument, in the format indicated by the associated format code. n is the above length.
|
||||||
|
argumentValue := src[rp : rp+argumentLength]
|
||||||
|
rp += argumentLength
|
||||||
|
arguments[i] = argumentValue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dst.Arguments = arguments
|
||||||
|
// The format code for the function result. Must presently be zero (text) or one (binary).
|
||||||
|
resultFormatCode := binary.BigEndian.Uint16(src[rp:])
|
||||||
|
if resultFormatCode != 0 && resultFormatCode != 1 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "FunctionCall"}
|
||||||
|
}
|
||||||
|
dst.ResultFormatCode = resultFormatCode
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *FunctionCall) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'F')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendUint32(dst, 0) // Unknown length, set it at the end
|
||||||
|
dst = pgio.AppendUint32(dst, src.Function)
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes)))
|
||||||
|
for _, argFormatCode := range src.ArgFormatCodes {
|
||||||
|
dst = pgio.AppendUint16(dst, argFormatCode)
|
||||||
|
}
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.Arguments)))
|
||||||
|
for _, argument := range src.Arguments {
|
||||||
|
if argument == nil {
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
} else {
|
||||||
|
dst = pgio.AppendInt32(dst, int32(len(argument)))
|
||||||
|
dst = append(dst, argument...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dst = pgio.AppendUint16(dst, src.ResultFormatCode)
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
return dst
|
||||||
|
}
|
||||||
@@ -0,0 +1,101 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type FunctionCallResponse struct {
|
||||||
|
Result []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*FunctionCallResponse) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *FunctionCallResponse) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
|
||||||
|
}
|
||||||
|
rp := 0
|
||||||
|
resultSize := int(binary.BigEndian.Uint32(src[rp:]))
|
||||||
|
rp += 4
|
||||||
|
|
||||||
|
if resultSize == -1 {
|
||||||
|
dst.Result = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src[rp:]) != resultSize {
|
||||||
|
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Result = src[rp:]
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *FunctionCallResponse) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'V')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
if src.Result == nil {
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
} else {
|
||||||
|
dst = pgio.AppendInt32(dst, int32(len(src.Result)))
|
||||||
|
dst = append(dst, src.Result...)
|
||||||
|
}
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src FunctionCallResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
var formattedValue map[string]string
|
||||||
|
var hasNonPrintable bool
|
||||||
|
for _, b := range src.Result {
|
||||||
|
if b < 32 {
|
||||||
|
hasNonPrintable = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasNonPrintable {
|
||||||
|
formattedValue = map[string]string{"binary": hex.EncodeToString(src.Result)}
|
||||||
|
} else {
|
||||||
|
formattedValue = map[string]string{"text": string(src.Result)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Result map[string]string
|
||||||
|
}{
|
||||||
|
Type: "FunctionCallResponse",
|
||||||
|
Result: formattedValue,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *FunctionCallResponse) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Result map[string]string
|
||||||
|
}
|
||||||
|
err := json.Unmarshal(data, &msg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.Result, err = getValueFromJSON(msg.Result)
|
||||||
|
return err
|
||||||
|
}
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFunctionCall_EncodeDecode(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
Function uint32
|
||||||
|
ArgFormatCodes []uint16
|
||||||
|
Arguments [][]byte
|
||||||
|
ResultFormatCode uint16
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"valid", fields{uint32(123), []uint16{0, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(1)}, false},
|
||||||
|
{"invalid format code", fields{uint32(123), []uint16{2, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(0)}, true},
|
||||||
|
{"invalid result format code", fields{uint32(123), []uint16{1, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(2)}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
src := &FunctionCall{
|
||||||
|
Function: tt.fields.Function,
|
||||||
|
ArgFormatCodes: tt.fields.ArgFormatCodes,
|
||||||
|
Arguments: tt.fields.Arguments,
|
||||||
|
ResultFormatCode: tt.fields.ResultFormatCode,
|
||||||
|
}
|
||||||
|
encoded := src.Encode([]byte{})
|
||||||
|
dst := &FunctionCall{}
|
||||||
|
// Check the header
|
||||||
|
msgTypeCode := encoded[0]
|
||||||
|
if msgTypeCode != 'F' {
|
||||||
|
t.Errorf("msgTypeCode %v should be 'F'", msgTypeCode)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Check length, does not include type code character
|
||||||
|
l := binary.BigEndian.Uint32(encoded[1:5])
|
||||||
|
if int(l) != (len(encoded) - 1) {
|
||||||
|
t.Errorf("Incorrect message length, got = %v, wanted = %v", l, len(encoded))
|
||||||
|
}
|
||||||
|
// Check decoding works as expected
|
||||||
|
err := dst.Decode(encoded[5:])
|
||||||
|
if err != nil {
|
||||||
|
if !tt.wantErr {
|
||||||
|
t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(src, dst) {
|
||||||
|
t.Error("difference after encode / decode cycle")
|
||||||
|
t.Errorf("src = %v", src)
|
||||||
|
t.Errorf("dst = %v", dst)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
module github.com/jackc/pgproto3/v2
|
||||||
|
|
||||||
|
go 1.12
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/jackc/chunkreader/v2 v2.0.0
|
||||||
|
github.com/jackc/pgio v1.0.0
|
||||||
|
github.com/stretchr/testify v1.4.0
|
||||||
|
)
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs=
|
||||||
|
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
|
||||||
|
github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
|
||||||
|
github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
||||||
|
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||||
|
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
const gssEncReqNumber = 80877104
|
||||||
|
|
||||||
|
type GSSEncRequest struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*GSSEncRequest) Frontend() {}
|
||||||
|
|
||||||
|
func (dst *GSSEncRequest) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("gss encoding request too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
requestCode := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if requestCode != gssEncReqNumber {
|
||||||
|
return errors.New("bad gss encoding request code")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||||
|
func (src *GSSEncRequest) Encode(dst []byte) []byte {
|
||||||
|
dst = pgio.AppendInt32(dst, 8)
|
||||||
|
dst = pgio.AppendInt32(dst, gssEncReqNumber)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src GSSEncRequest) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ProtocolVersion uint32
|
||||||
|
Parameters map[string]string
|
||||||
|
}{
|
||||||
|
Type: "GSSEncRequest",
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,572 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestJSONUnmarshalAuthenticationMD5Password(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"AuthenticationMD5Password", "Salt":[97,98,99,100]}`)
|
||||||
|
want := AuthenticationMD5Password{
|
||||||
|
Salt: [4]byte{'a', 'b', 'c', 'd'},
|
||||||
|
}
|
||||||
|
|
||||||
|
var got AuthenticationMD5Password
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled AuthenticationMD5Password struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalAuthenticationSASL(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"AuthenticationSASL","AuthMechanisms":["SCRAM-SHA-256"]}`)
|
||||||
|
want := AuthenticationSASL{
|
||||||
|
[]string{"SCRAM-SHA-256"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var got AuthenticationSASL
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled AuthenticationSASL struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalAuthenticationSASLContinue(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"AuthenticationSASLContinue", "Data":"1"}`)
|
||||||
|
want := AuthenticationSASLContinue{
|
||||||
|
Data: []byte{'1'},
|
||||||
|
}
|
||||||
|
|
||||||
|
var got AuthenticationSASLContinue
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled AuthenticationSASLContinue struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalAuthenticationSASLFinal(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"AuthenticationSASLFinal", "Data":"1"}`)
|
||||||
|
want := AuthenticationSASLFinal{
|
||||||
|
Data: []byte{'1'},
|
||||||
|
}
|
||||||
|
|
||||||
|
var got AuthenticationSASLFinal
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled AuthenticationSASLFinal struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalBackendKeyData(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"BackendKeyData","ProcessID":8864,"SecretKey":3641487067}`)
|
||||||
|
want := BackendKeyData{
|
||||||
|
ProcessID: 8864,
|
||||||
|
SecretKey: 3641487067,
|
||||||
|
}
|
||||||
|
|
||||||
|
var got BackendKeyData
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled BackendKeyData struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalCommandComplete(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"CommandComplete","CommandTag":"SELECT 1"}`)
|
||||||
|
want := CommandComplete{
|
||||||
|
CommandTag: []byte("SELECT 1"),
|
||||||
|
}
|
||||||
|
|
||||||
|
var got CommandComplete
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled CommandComplete struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalCopyBothResponse(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"CopyBothResponse", "OverallFormat": "W"}`)
|
||||||
|
want := CopyBothResponse{
|
||||||
|
OverallFormat: 'W',
|
||||||
|
}
|
||||||
|
|
||||||
|
var got CopyBothResponse
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled CopyBothResponse struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalCopyData(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"CopyData"}`)
|
||||||
|
want := CopyData{
|
||||||
|
Data: []byte{},
|
||||||
|
}
|
||||||
|
|
||||||
|
var got CopyData
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled CopyData struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalCopyInResponse(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"CopyBothResponse", "OverallFormat": "W"}`)
|
||||||
|
want := CopyBothResponse{
|
||||||
|
OverallFormat: 'W',
|
||||||
|
}
|
||||||
|
|
||||||
|
var got CopyBothResponse
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled CopyBothResponse struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalCopyOutResponse(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"CopyOutResponse", "OverallFormat": "W"}`)
|
||||||
|
want := CopyOutResponse{
|
||||||
|
OverallFormat: 'W',
|
||||||
|
}
|
||||||
|
|
||||||
|
var got CopyOutResponse
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled CopyOutResponse struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalDataRow(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"DataRow","Values":[{"text":"abc"},{"text":"this is a test"},{"binary":"000263d3114d2e34"}]}`)
|
||||||
|
want := DataRow{
|
||||||
|
Values: [][]byte{
|
||||||
|
[]byte("abc"),
|
||||||
|
[]byte("this is a test"),
|
||||||
|
{0, 2, 99, 211, 17, 77, 46, 52},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var got DataRow
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled DataRow struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalErrorResponse(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"ErrorResponse", "UnknownFields": {"97": "foo"}}`)
|
||||||
|
want := ErrorResponse{
|
||||||
|
UnknownFields: map[byte]string{
|
||||||
|
'a': "foo",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var got ErrorResponse
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled ErrorResponse struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalFunctionCallResponse(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"FunctionCallResponse"}`)
|
||||||
|
want := FunctionCallResponse{}
|
||||||
|
|
||||||
|
var got FunctionCallResponse
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled FunctionCallResponse struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalNoticeResponse(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"NoticeResponse", "UnknownFields": {"97": "foo"}}`)
|
||||||
|
want := NoticeResponse{
|
||||||
|
UnknownFields: map[byte]string{
|
||||||
|
'a': "foo",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var got NoticeResponse
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled NoticeResponse struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalNotificationResponse(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"NotificationResponse"}`)
|
||||||
|
want := NotificationResponse{}
|
||||||
|
|
||||||
|
var got NotificationResponse
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled NotificationResponse struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalParameterDescription(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"ParameterDescription", "ParameterOIDs": [25]}`)
|
||||||
|
want := ParameterDescription{
|
||||||
|
ParameterOIDs: []uint32{25},
|
||||||
|
}
|
||||||
|
|
||||||
|
var got ParameterDescription
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled ParameterDescription struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalParameterStatus(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"ParameterStatus","Name":"TimeZone","Value":"Europe/Amsterdam"}`)
|
||||||
|
want := ParameterStatus{
|
||||||
|
Name: "TimeZone",
|
||||||
|
Value: "Europe/Amsterdam",
|
||||||
|
}
|
||||||
|
|
||||||
|
var got ParameterStatus
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled ParameterDescription struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalReadyForQuery(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"ReadyForQuery","TxStatus":"I"}`)
|
||||||
|
want := ReadyForQuery{
|
||||||
|
TxStatus: 'I',
|
||||||
|
}
|
||||||
|
|
||||||
|
var got ReadyForQuery
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled ParameterDescription struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalRowDescription(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"RowDescription","Fields":[{"Name":"generate_series","TableOID":0,"TableAttributeNumber":0,"DataTypeOID":23,"DataTypeSize":4,"TypeModifier":-1,"Format":0}]}`)
|
||||||
|
want := RowDescription{
|
||||||
|
Fields: []FieldDescription{
|
||||||
|
{
|
||||||
|
Name: []byte("generate_series"),
|
||||||
|
DataTypeOID: 23,
|
||||||
|
DataTypeSize: 4,
|
||||||
|
TypeModifier: -1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var got RowDescription
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled RowDescription struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalBind(t *testing.T) {
|
||||||
|
var testCases = []struct {
|
||||||
|
desc string
|
||||||
|
data []byte
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"textual",
|
||||||
|
[]byte(`{"Type":"Bind","DestinationPortal":"","PreparedStatement":"lrupsc_1_0","ParameterFormatCodes":[0],"Parameters":[{"text":"ABC-123"}],"ResultFormatCodes":[0,0,0,0,0,1,1]}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"binary",
|
||||||
|
[]byte(`{"Type":"Bind","DestinationPortal":"","PreparedStatement":"lrupsc_1_0","ParameterFormatCodes":[0],"Parameters":[{"binary":"` + hex.EncodeToString([]byte("ABC-123")) + `"}],"ResultFormatCodes":[0,0,0,0,0,1,1]}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.desc, func(t *testing.T) {
|
||||||
|
var want = Bind{
|
||||||
|
PreparedStatement: "lrupsc_1_0",
|
||||||
|
ParameterFormatCodes: []int16{0},
|
||||||
|
Parameters: [][]byte{[]byte("ABC-123")},
|
||||||
|
ResultFormatCodes: []int16{0, 0, 0, 0, 0, 1, 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
var got Bind
|
||||||
|
if err := json.Unmarshal(tc.data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled Bind struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalCancelRequest(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"CancelRequest","ProcessID":8864,"SecretKey":3641487067}`)
|
||||||
|
want := CancelRequest{
|
||||||
|
ProcessID: 8864,
|
||||||
|
SecretKey: 3641487067,
|
||||||
|
}
|
||||||
|
|
||||||
|
var got CancelRequest
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled CancelRequest struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalClose(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"Close","ObjectType":"S","Name":"abc"}`)
|
||||||
|
want := Close{
|
||||||
|
ObjectType: 'S',
|
||||||
|
Name: "abc",
|
||||||
|
}
|
||||||
|
|
||||||
|
var got Close
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled Close struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalCopyFail(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"CopyFail","Message":"abc"}`)
|
||||||
|
want := CopyFail{
|
||||||
|
Message: "abc",
|
||||||
|
}
|
||||||
|
|
||||||
|
var got CopyFail
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled CopyFail struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalDescribe(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"Describe","ObjectType":"S","Name":"abc"}`)
|
||||||
|
want := Describe{
|
||||||
|
ObjectType: 'S',
|
||||||
|
Name: "abc",
|
||||||
|
}
|
||||||
|
|
||||||
|
var got Describe
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled Describe struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalExecute(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"Execute","Portal":"","MaxRows":0}`)
|
||||||
|
want := Execute{}
|
||||||
|
|
||||||
|
var got Execute
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled Execute struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalParse(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"Parse","Name":"lrupsc_1_0","Query":"SELECT id, name FROM t WHERE id = $1","ParameterOIDs":null}`)
|
||||||
|
want := Parse{
|
||||||
|
Name: "lrupsc_1_0",
|
||||||
|
Query: "SELECT id, name FROM t WHERE id = $1",
|
||||||
|
}
|
||||||
|
|
||||||
|
var got Parse
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled Parse struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalPasswordMessage(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"PasswordMessage","Password":"abcdef"}`)
|
||||||
|
want := PasswordMessage{
|
||||||
|
Password: "abcdef",
|
||||||
|
}
|
||||||
|
|
||||||
|
var got PasswordMessage
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled PasswordMessage struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalQuery(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"Query","String":"SELECT 1"}`)
|
||||||
|
want := Query{
|
||||||
|
String: "SELECT 1",
|
||||||
|
}
|
||||||
|
|
||||||
|
var got Query
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled Query struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalSASLInitialResponse(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"SASLInitialResponse", "AuthMechanism":"SCRAM-SHA-256", "Data": "6D"}`)
|
||||||
|
want := SASLInitialResponse{
|
||||||
|
AuthMechanism: "SCRAM-SHA-256",
|
||||||
|
Data: []byte{109},
|
||||||
|
}
|
||||||
|
|
||||||
|
var got SASLInitialResponse
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled SASLInitialResponse struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalSASLResponse(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"SASLResponse","Message":"abc"}`)
|
||||||
|
want := SASLResponse{}
|
||||||
|
|
||||||
|
var got SASLResponse
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled SASLResponse struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalStartupMessage(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"StartupMessage","ProtocolVersion":196608,"Parameters":{"database":"testing","user":"postgres"}}`)
|
||||||
|
want := StartupMessage{
|
||||||
|
ProtocolVersion: 196608,
|
||||||
|
Parameters: map[string]string{
|
||||||
|
"database": "testing",
|
||||||
|
"user": "postgres",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var got StartupMessage
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled StartupMessage struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthenticationOK(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"AuthenticationOK"}`)
|
||||||
|
want := AuthenticationOk{}
|
||||||
|
|
||||||
|
var got AuthenticationOk
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled AuthenticationOK struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthenticationCleartextPassword(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"AuthenticationCleartextPassword"}`)
|
||||||
|
want := AuthenticationCleartextPassword{}
|
||||||
|
|
||||||
|
var got AuthenticationCleartextPassword
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled AuthenticationCleartextPassword struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthenticationMD5Password(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"AuthenticationMD5Password","Salt":[1,2,3,4]}`)
|
||||||
|
want := AuthenticationMD5Password{
|
||||||
|
Salt: [4]byte{1, 2, 3, 4},
|
||||||
|
}
|
||||||
|
|
||||||
|
var got AuthenticationMD5Password
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled AuthenticationMD5Password struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestErrorResponse(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"ErrorResponse","UnknownFields":{"112":"foo"},"Code": "Fail","Position":1,"Message":"this is an error"}`)
|
||||||
|
want := ErrorResponse{
|
||||||
|
UnknownFields: map[byte]string{
|
||||||
|
'p': "foo",
|
||||||
|
},
|
||||||
|
Code: "Fail",
|
||||||
|
Position: 1,
|
||||||
|
Message: "this is an error",
|
||||||
|
}
|
||||||
|
|
||||||
|
var got ErrorResponse
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled ErrorResponse struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type NoData struct{}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*NoData) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *NoData) Decode(src []byte) error {
|
||||||
|
if len(src) != 0 {
|
||||||
|
return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *NoData) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, 'n', 0, 0, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src NoData) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "NoData",
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
type NoticeResponse ErrorResponse
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*NoticeResponse) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *NoticeResponse) Decode(src []byte) error {
|
||||||
|
return (*ErrorResponse)(dst).Decode(src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *NoticeResponse) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, (*ErrorResponse)(src).marshalBinary('N')...)
|
||||||
|
}
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type NotificationResponse struct {
|
||||||
|
PID uint32
|
||||||
|
Channel string
|
||||||
|
Payload string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*NotificationResponse) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *NotificationResponse) Decode(src []byte) error {
|
||||||
|
buf := bytes.NewBuffer(src)
|
||||||
|
|
||||||
|
pid := binary.BigEndian.Uint32(buf.Next(4))
|
||||||
|
|
||||||
|
b, err := buf.ReadBytes(0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
channel := string(b[:len(b)-1])
|
||||||
|
|
||||||
|
b, err = buf.ReadBytes(0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
payload := string(b[:len(b)-1])
|
||||||
|
|
||||||
|
*dst = NotificationResponse{PID: pid, Channel: channel, Payload: payload}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *NotificationResponse) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'A')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = pgio.AppendUint32(dst, src.PID)
|
||||||
|
dst = append(dst, src.Channel...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
dst = append(dst, src.Payload...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src NotificationResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
PID uint32
|
||||||
|
Channel string
|
||||||
|
Payload string
|
||||||
|
}{
|
||||||
|
Type: "NotificationResponse",
|
||||||
|
PID: src.PID,
|
||||||
|
Channel: src.Channel,
|
||||||
|
Payload: src.Payload,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,66 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ParameterDescription struct {
|
||||||
|
ParameterOIDs []uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*ParameterDescription) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *ParameterDescription) Decode(src []byte) error {
|
||||||
|
buf := bytes.NewBuffer(src)
|
||||||
|
|
||||||
|
if buf.Len() < 2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "ParameterDescription"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reported parameter count will be incorrect when number of args is greater than uint16
|
||||||
|
buf.Next(2)
|
||||||
|
// Instead infer parameter count by remaining size of message
|
||||||
|
parameterCount := buf.Len() / 4
|
||||||
|
|
||||||
|
*dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)}
|
||||||
|
|
||||||
|
for i := 0; i < parameterCount; i++ {
|
||||||
|
dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *ParameterDescription) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 't')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
|
||||||
|
for _, oid := range src.ParameterOIDs {
|
||||||
|
dst = pgio.AppendUint32(dst, oid)
|
||||||
|
}
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src ParameterDescription) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ParameterOIDs []uint32
|
||||||
|
}{
|
||||||
|
Type: "ParameterDescription",
|
||||||
|
ParameterOIDs: src.ParameterOIDs,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,66 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ParameterStatus struct {
|
||||||
|
Name string
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*ParameterStatus) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *ParameterStatus) Decode(src []byte) error {
|
||||||
|
buf := bytes.NewBuffer(src)
|
||||||
|
|
||||||
|
b, err := buf.ReadBytes(0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
name := string(b[:len(b)-1])
|
||||||
|
|
||||||
|
b, err = buf.ReadBytes(0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
value := string(b[:len(b)-1])
|
||||||
|
|
||||||
|
*dst = ParameterStatus{Name: name, Value: value}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *ParameterStatus) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'S')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.Name...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
dst = append(dst, src.Value...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (ps ParameterStatus) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Name string
|
||||||
|
Value string
|
||||||
|
}{
|
||||||
|
Type: "ParameterStatus",
|
||||||
|
Name: ps.Name,
|
||||||
|
Value: ps.Value,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Parse struct {
|
||||||
|
Name string
|
||||||
|
Query string
|
||||||
|
ParameterOIDs []uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*Parse) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *Parse) Decode(src []byte) error {
|
||||||
|
*dst = Parse{}
|
||||||
|
|
||||||
|
buf := bytes.NewBuffer(src)
|
||||||
|
|
||||||
|
b, err := buf.ReadBytes(0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.Name = string(b[:len(b)-1])
|
||||||
|
|
||||||
|
b, err = buf.ReadBytes(0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.Query = string(b[:len(b)-1])
|
||||||
|
|
||||||
|
if buf.Len() < 2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Parse"}
|
||||||
|
}
|
||||||
|
parameterOIDCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||||
|
|
||||||
|
for i := 0; i < parameterOIDCount; i++ {
|
||||||
|
if buf.Len() < 4 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Parse"}
|
||||||
|
}
|
||||||
|
dst.ParameterOIDs = append(dst.ParameterOIDs, binary.BigEndian.Uint32(buf.Next(4)))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *Parse) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'P')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.Name...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
dst = append(dst, src.Query...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
|
||||||
|
for _, oid := range src.ParameterOIDs {
|
||||||
|
dst = pgio.AppendUint32(dst, oid)
|
||||||
|
}
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src Parse) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Name string
|
||||||
|
Query string
|
||||||
|
ParameterOIDs []uint32
|
||||||
|
}{
|
||||||
|
Type: "Parse",
|
||||||
|
Name: src.Name,
|
||||||
|
Query: src.Query,
|
||||||
|
ParameterOIDs: src.ParameterOIDs,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ParseComplete struct{}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*ParseComplete) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *ParseComplete) Decode(src []byte) error {
|
||||||
|
if len(src) != 0 {
|
||||||
|
return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *ParseComplete) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, '1', 0, 0, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src ParseComplete) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "ParseComplete",
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PasswordMessage struct {
|
||||||
|
Password string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*PasswordMessage) Frontend() {}
|
||||||
|
|
||||||
|
// Frontend identifies this message as an authentication response.
|
||||||
|
func (*PasswordMessage) InitialResponse() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *PasswordMessage) Decode(src []byte) error {
|
||||||
|
buf := bytes.NewBuffer(src)
|
||||||
|
|
||||||
|
b, err := buf.ReadBytes(0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.Password = string(b[:len(b)-1])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *PasswordMessage) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'p')
|
||||||
|
dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1))
|
||||||
|
|
||||||
|
dst = append(dst, src.Password...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src PasswordMessage) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Password string
|
||||||
|
}{
|
||||||
|
Type: "PasswordMessage",
|
||||||
|
Password: src.Password,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,65 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Message is the interface implemented by an object that can decode and encode
|
||||||
|
// a particular PostgreSQL message.
|
||||||
|
type Message interface {
|
||||||
|
// Decode is allowed and expected to retain a reference to data after
|
||||||
|
// returning (unlike encoding.BinaryUnmarshaler).
|
||||||
|
Decode(data []byte) error
|
||||||
|
|
||||||
|
// Encode appends itself to dst and returns the new buffer.
|
||||||
|
Encode(dst []byte) []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type FrontendMessage interface {
|
||||||
|
Message
|
||||||
|
Frontend() // no-op method to distinguish frontend from backend methods
|
||||||
|
}
|
||||||
|
|
||||||
|
type BackendMessage interface {
|
||||||
|
Message
|
||||||
|
Backend() // no-op method to distinguish frontend from backend methods
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthenticationResponseMessage interface {
|
||||||
|
BackendMessage
|
||||||
|
AuthenticationResponse() // no-op method to distinguish authentication responses
|
||||||
|
}
|
||||||
|
|
||||||
|
type invalidMessageLenErr struct {
|
||||||
|
messageType string
|
||||||
|
expectedLen int
|
||||||
|
actualLen int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *invalidMessageLenErr) Error() string {
|
||||||
|
return fmt.Sprintf("%s body must have length of %d, but it is %d", e.messageType, e.expectedLen, e.actualLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
type invalidMessageFormatErr struct {
|
||||||
|
messageType string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *invalidMessageFormatErr) Error() string {
|
||||||
|
return fmt.Sprintf("%s body is invalid", e.messageType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getValueFromJSON gets the value from a protocol message representation in JSON.
|
||||||
|
func getValueFromJSON(v map[string]string) ([]byte, error) {
|
||||||
|
if v == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if text, ok := v["text"]; ok {
|
||||||
|
return []byte(text), nil
|
||||||
|
}
|
||||||
|
if binary, ok := v["binary"]; ok {
|
||||||
|
return hex.DecodeString(binary)
|
||||||
|
}
|
||||||
|
return nil, errors.New("unknown protocol representation")
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PortalSuspended struct{}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*PortalSuspended) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *PortalSuspended) Decode(src []byte) error {
|
||||||
|
if len(src) != 0 {
|
||||||
|
return &invalidMessageLenErr{messageType: "PortalSuspended", expectedLen: 0, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *PortalSuspended) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, 's', 0, 0, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src PortalSuspended) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "PortalSuspended",
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,50 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Query struct {
|
||||||
|
String string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*Query) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *Query) Decode(src []byte) error {
|
||||||
|
i := bytes.IndexByte(src, 0)
|
||||||
|
if i != len(src)-1 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Query"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.String = string(src[:i])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *Query) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'Q')
|
||||||
|
dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1))
|
||||||
|
|
||||||
|
dst = append(dst, src.String...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src Query) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
String string
|
||||||
|
}{
|
||||||
|
Type: "Query",
|
||||||
|
String: src.String,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,61 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ReadyForQuery struct {
|
||||||
|
TxStatus byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*ReadyForQuery) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *ReadyForQuery) Decode(src []byte) error {
|
||||||
|
if len(src) != 1 {
|
||||||
|
return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.TxStatus = src[0]
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *ReadyForQuery) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src ReadyForQuery) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
TxStatus string
|
||||||
|
}{
|
||||||
|
Type: "ReadyForQuery",
|
||||||
|
TxStatus: string(src.TxStatus),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *ReadyForQuery) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
TxStatus string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(msg.TxStatus) != 1 {
|
||||||
|
return errors.New("invalid length for ReadyForQuery.TxStatus")
|
||||||
|
}
|
||||||
|
dst.TxStatus = msg.TxStatus[0]
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,165 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
TextFormat = 0
|
||||||
|
BinaryFormat = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
type FieldDescription struct {
|
||||||
|
Name []byte
|
||||||
|
TableOID uint32
|
||||||
|
TableAttributeNumber uint16
|
||||||
|
DataTypeOID uint32
|
||||||
|
DataTypeSize int16
|
||||||
|
TypeModifier int32
|
||||||
|
Format int16
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (fd FieldDescription) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Name string
|
||||||
|
TableOID uint32
|
||||||
|
TableAttributeNumber uint16
|
||||||
|
DataTypeOID uint32
|
||||||
|
DataTypeSize int16
|
||||||
|
TypeModifier int32
|
||||||
|
Format int16
|
||||||
|
}{
|
||||||
|
Name: string(fd.Name),
|
||||||
|
TableOID: fd.TableOID,
|
||||||
|
TableAttributeNumber: fd.TableAttributeNumber,
|
||||||
|
DataTypeOID: fd.DataTypeOID,
|
||||||
|
DataTypeSize: fd.DataTypeSize,
|
||||||
|
TypeModifier: fd.TypeModifier,
|
||||||
|
Format: fd.Format,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type RowDescription struct {
|
||||||
|
Fields []FieldDescription
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*RowDescription) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *RowDescription) Decode(src []byte) error {
|
||||||
|
|
||||||
|
if len(src) < 2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "RowDescription"}
|
||||||
|
}
|
||||||
|
fieldCount := int(binary.BigEndian.Uint16(src))
|
||||||
|
rp := 2
|
||||||
|
|
||||||
|
dst.Fields = dst.Fields[0:0]
|
||||||
|
|
||||||
|
for i := 0; i < fieldCount; i++ {
|
||||||
|
var fd FieldDescription
|
||||||
|
|
||||||
|
idx := bytes.IndexByte(src[rp:], 0)
|
||||||
|
if idx < 0 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "RowDescription"}
|
||||||
|
}
|
||||||
|
fd.Name = src[rp : rp+idx]
|
||||||
|
rp += idx + 1
|
||||||
|
|
||||||
|
// Since buf.Next() doesn't return an error if we hit the end of the buffer
|
||||||
|
// check Len ahead of time
|
||||||
|
if len(src[rp:]) < 18 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "RowDescription"}
|
||||||
|
}
|
||||||
|
|
||||||
|
fd.TableOID = binary.BigEndian.Uint32(src[rp:])
|
||||||
|
rp += 4
|
||||||
|
fd.TableAttributeNumber = binary.BigEndian.Uint16(src[rp:])
|
||||||
|
rp += 2
|
||||||
|
fd.DataTypeOID = binary.BigEndian.Uint32(src[rp:])
|
||||||
|
rp += 4
|
||||||
|
fd.DataTypeSize = int16(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
fd.TypeModifier = int32(binary.BigEndian.Uint32(src[rp:]))
|
||||||
|
rp += 4
|
||||||
|
fd.Format = int16(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
|
||||||
|
dst.Fields = append(dst.Fields, fd)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *RowDescription) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'T')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.Fields)))
|
||||||
|
for _, fd := range src.Fields {
|
||||||
|
dst = append(dst, fd.Name...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
dst = pgio.AppendUint32(dst, fd.TableOID)
|
||||||
|
dst = pgio.AppendUint16(dst, fd.TableAttributeNumber)
|
||||||
|
dst = pgio.AppendUint32(dst, fd.DataTypeOID)
|
||||||
|
dst = pgio.AppendInt16(dst, fd.DataTypeSize)
|
||||||
|
dst = pgio.AppendInt32(dst, fd.TypeModifier)
|
||||||
|
dst = pgio.AppendInt16(dst, fd.Format)
|
||||||
|
}
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src RowDescription) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Fields []FieldDescription
|
||||||
|
}{
|
||||||
|
Type: "RowDescription",
|
||||||
|
Fields: src.Fields,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *RowDescription) UnmarshalJSON(data []byte) error {
|
||||||
|
var msg struct {
|
||||||
|
Fields []struct {
|
||||||
|
Name string
|
||||||
|
TableOID uint32
|
||||||
|
TableAttributeNumber uint16
|
||||||
|
DataTypeOID uint32
|
||||||
|
DataTypeSize int16
|
||||||
|
TypeModifier int32
|
||||||
|
Format int16
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.Fields = make([]FieldDescription, len(msg.Fields))
|
||||||
|
for n, field := range msg.Fields {
|
||||||
|
dst.Fields[n] = FieldDescription{
|
||||||
|
Name: []byte(field.Name),
|
||||||
|
TableOID: field.TableOID,
|
||||||
|
TableAttributeNumber: field.TableAttributeNumber,
|
||||||
|
DataTypeOID: field.DataTypeOID,
|
||||||
|
DataTypeSize: field.DataTypeSize,
|
||||||
|
TypeModifier: field.TypeModifier,
|
||||||
|
Format: field.Format,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,94 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SASLInitialResponse struct {
|
||||||
|
AuthMechanism string
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*SASLInitialResponse) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *SASLInitialResponse) Decode(src []byte) error {
|
||||||
|
*dst = SASLInitialResponse{}
|
||||||
|
|
||||||
|
rp := 0
|
||||||
|
|
||||||
|
idx := bytes.IndexByte(src, 0)
|
||||||
|
if idx < 0 {
|
||||||
|
return errors.New("invalid SASLInitialResponse")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.AuthMechanism = string(src[rp:idx])
|
||||||
|
rp = idx + 1
|
||||||
|
|
||||||
|
rp += 4 // The rest of the message is data so we can just skip the size
|
||||||
|
dst.Data = src[rp:]
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *SASLInitialResponse) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'p')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, []byte(src.AuthMechanism)...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
dst = pgio.AppendInt32(dst, int32(len(src.Data)))
|
||||||
|
dst = append(dst, src.Data...)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src SASLInitialResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
AuthMechanism string
|
||||||
|
Data string
|
||||||
|
}{
|
||||||
|
Type: "SASLInitialResponse",
|
||||||
|
AuthMechanism: src.AuthMechanism,
|
||||||
|
Data: string(src.Data),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *SASLInitialResponse) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
AuthMechanism string
|
||||||
|
Data string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.AuthMechanism = msg.AuthMechanism
|
||||||
|
if msg.Data != "" {
|
||||||
|
decoded, err := hex.DecodeString(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.Data = decoded
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,61 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SASLResponse struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*SASLResponse) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *SASLResponse) Decode(src []byte) error {
|
||||||
|
*dst = SASLResponse{Data: src}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *SASLResponse) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'p')
|
||||||
|
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
|
||||||
|
|
||||||
|
dst = append(dst, src.Data...)
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src SASLResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Data string
|
||||||
|
}{
|
||||||
|
Type: "SASLResponse",
|
||||||
|
Data: string(src.Data),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *SASLResponse) UnmarshalJSON(data []byte) error {
|
||||||
|
var msg struct {
|
||||||
|
Data string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if msg.Data != "" {
|
||||||
|
decoded, err := hex.DecodeString(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.Data = decoded
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
const sslRequestNumber = 80877103
|
||||||
|
|
||||||
|
type SSLRequest struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*SSLRequest) Frontend() {}
|
||||||
|
|
||||||
|
func (dst *SSLRequest) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("ssl request too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
requestCode := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if requestCode != sslRequestNumber {
|
||||||
|
return errors.New("bad ssl request code")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||||
|
func (src *SSLRequest) Encode(dst []byte) []byte {
|
||||||
|
dst = pgio.AppendInt32(dst, 8)
|
||||||
|
dst = pgio.AppendInt32(dst, sslRequestNumber)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src SSLRequest) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ProtocolVersion uint32
|
||||||
|
Parameters map[string]string
|
||||||
|
}{
|
||||||
|
Type: "SSLRequest",
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,96 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
const ProtocolVersionNumber = 196608 // 3.0
|
||||||
|
|
||||||
|
type StartupMessage struct {
|
||||||
|
ProtocolVersion uint32
|
||||||
|
Parameters map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*StartupMessage) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *StartupMessage) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("startup message too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.ProtocolVersion = binary.BigEndian.Uint32(src)
|
||||||
|
rp := 4
|
||||||
|
|
||||||
|
if dst.ProtocolVersion != ProtocolVersionNumber {
|
||||||
|
return fmt.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Parameters = make(map[string]string)
|
||||||
|
for {
|
||||||
|
idx := bytes.IndexByte(src[rp:], 0)
|
||||||
|
if idx < 0 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "StartupMesage"}
|
||||||
|
}
|
||||||
|
key := string(src[rp : rp+idx])
|
||||||
|
rp += idx + 1
|
||||||
|
|
||||||
|
idx = bytes.IndexByte(src[rp:], 0)
|
||||||
|
if idx < 0 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "StartupMesage"}
|
||||||
|
}
|
||||||
|
value := string(src[rp : rp+idx])
|
||||||
|
rp += idx + 1
|
||||||
|
|
||||||
|
dst.Parameters[key] = value
|
||||||
|
|
||||||
|
if len(src[rp:]) == 1 {
|
||||||
|
if src[rp] != 0 {
|
||||||
|
return fmt.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp])
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *StartupMessage) Encode(dst []byte) []byte {
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = pgio.AppendUint32(dst, src.ProtocolVersion)
|
||||||
|
for k, v := range src.Parameters {
|
||||||
|
dst = append(dst, k...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
dst = append(dst, v...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
}
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src StartupMessage) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ProtocolVersion uint32
|
||||||
|
Parameters map[string]string
|
||||||
|
}{
|
||||||
|
Type: "StartupMessage",
|
||||||
|
ProtocolVersion: src.ProtocolVersion,
|
||||||
|
Parameters: src.Parameters,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Sync struct{}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*Sync) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *Sync) Decode(src []byte) error {
|
||||||
|
if len(src) != 0 {
|
||||||
|
return &invalidMessageLenErr{messageType: "Sync", expectedLen: 0, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *Sync) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, 'S', 0, 0, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src Sync) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "Sync",
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Terminate struct{}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*Terminate) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *Terminate) Decode(src []byte) error {
|
||||||
|
if len(src) != 0 {
|
||||||
|
return &invalidMessageLenErr{messageType: "Terminate", expectedLen: 0, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *Terminate) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, 'X', 0, 0, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src Terminate) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "Terminate",
|
||||||
|
})
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user