5714896b10
Use an internal buffer in pgproto3.Frontend and pgproto3.Backend instead of writing directly to the underlying net.Conn. This will allow message tracing and simplify pipeline mode.
225 lines
5.8 KiB
Go
225 lines
5.8 KiB
Go
package pgproto3
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
)
|
|
|
|
// Frontend acts as a client for the PostgreSQL wire protocol version 3.
|
|
type Frontend struct {
|
|
cr *chunkReader
|
|
w io.Writer
|
|
|
|
wbuf []byte
|
|
|
|
// Backend message flyweights
|
|
authenticationOk AuthenticationOk
|
|
authenticationCleartextPassword AuthenticationCleartextPassword
|
|
authenticationMD5Password AuthenticationMD5Password
|
|
authenticationGSS AuthenticationGSS
|
|
authenticationGSSContinue AuthenticationGSSContinue
|
|
authenticationSASL AuthenticationSASL
|
|
authenticationSASLContinue AuthenticationSASLContinue
|
|
authenticationSASLFinal AuthenticationSASLFinal
|
|
backendKeyData BackendKeyData
|
|
bindComplete BindComplete
|
|
closeComplete CloseComplete
|
|
commandComplete CommandComplete
|
|
copyBothResponse CopyBothResponse
|
|
copyData CopyData
|
|
copyInResponse CopyInResponse
|
|
copyOutResponse CopyOutResponse
|
|
copyDone CopyDone
|
|
dataRow DataRow
|
|
emptyQueryResponse EmptyQueryResponse
|
|
errorResponse ErrorResponse
|
|
functionCallResponse FunctionCallResponse
|
|
noData NoData
|
|
noticeResponse NoticeResponse
|
|
notificationResponse NotificationResponse
|
|
parameterDescription ParameterDescription
|
|
parameterStatus ParameterStatus
|
|
parseComplete ParseComplete
|
|
readyForQuery ReadyForQuery
|
|
rowDescription RowDescription
|
|
portalSuspended PortalSuspended
|
|
|
|
bodyLen int
|
|
msgType byte
|
|
partialMsg bool
|
|
authType uint32
|
|
}
|
|
|
|
// NewFrontend returns a Frontend that reads backend messages from r and sends
// frontend messages to w. Output is buffered: Send only queues a message, and
// nothing reaches w until Flush is called.
func NewFrontend(r io.Reader, w io.Writer) *Frontend {
	f := &Frontend{w: w}
	// NOTE(review): the 0 passed to newChunkReader presumably selects its
	// default buffer size; confirm against chunkReader.
	f.cr = newChunkReader(r, 0)
	return f
}
|
|
|
|
// Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is
|
|
// called.
|
|
func (f *Frontend) Send(msg FrontendMessage) {
|
|
f.wbuf = msg.Encode(f.wbuf)
|
|
}
|
|
|
|
// Flush writes any pending messages to the backend (i.e. the server).
|
|
func (f *Frontend) Flush() error {
|
|
n, err := f.w.Write(f.wbuf)
|
|
|
|
const maxLen = 1024
|
|
if len(f.wbuf) > maxLen {
|
|
f.wbuf = make([]byte, 0, maxLen)
|
|
} else {
|
|
f.wbuf = f.wbuf[:0]
|
|
}
|
|
|
|
if err != nil {
|
|
return &writeError{err: err, safeToRetry: n == 0}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// translateEOFtoErrUnexpectedEOF converts io.EOF into io.ErrUnexpectedEOF and
// returns any other error, including nil, unchanged. The comparison is by
// identity, so an error that merely wraps io.EOF is passed through as-is.
//
// Receive applies this to both header and body reads, so end of input while
// waiting for server data is reported as unexpected.
func translateEOFtoErrUnexpectedEOF(err error) error {
	switch err {
	case io.EOF:
		return io.ErrUnexpectedEOF
	default:
		return err
	}
}
|
|
|
|
// Receive receives a message from the backend. The returned message is only
// valid until the next call to Receive: it points at a flyweight owned by the
// Frontend that later calls overwrite.
//
// A message is read as a 5-byte header (type byte plus 4-byte big-endian
// length) followed by the body. If the body read fails, partialMsg stays set,
// so the next call skips the header and requests the body again. io.EOF from
// the reader is reported as io.ErrUnexpectedEOF.
//
// A header whose length field is below 4 (the size of the length field
// itself), or whose body length does not fit in an int, is rejected with an
// error. After such an error the stream is presumably out of sync and the
// connection should be abandoned. On every error the returned message is nil.
func (f *Frontend) Receive() (BackendMessage, error) {
	if !f.partialMsg {
		header, err := f.cr.Next(5)
		if err != nil {
			return nil, translateEOFtoErrUnexpectedEOF(err)
		}

		// The length field counts its own 4 bytes but not the type byte.
		// Validate before using it so a malformed length can never hand a
		// negative size to cr.Next. bodyLen < 0 can only happen where int is
		// 32 bits wide and the length is at least 2^31+4.
		msgLength := binary.BigEndian.Uint32(header[1:])
		bodyLen := int(msgLength - 4)
		if msgLength < 4 || bodyLen < 0 {
			return nil, fmt.Errorf("invalid message length: %d", msgLength)
		}

		f.msgType = header[0]
		f.bodyLen = bodyLen
		f.partialMsg = true
	}

	msgBody, err := f.cr.Next(f.bodyLen)
	if err != nil {
		return nil, translateEOFtoErrUnexpectedEOF(err)
	}

	// The whole message has been consumed; the next call starts a new header.
	f.partialMsg = false

	var msg BackendMessage
	switch f.msgType {
	case '1':
		msg = &f.parseComplete
	case '2':
		msg = &f.bindComplete
	case '3':
		msg = &f.closeComplete
	case 'A':
		msg = &f.notificationResponse
	case 'c':
		msg = &f.copyDone
	case 'C':
		msg = &f.commandComplete
	case 'd':
		msg = &f.copyData
	case 'D':
		msg = &f.dataRow
	case 'E':
		msg = &f.errorResponse
	case 'G':
		msg = &f.copyInResponse
	case 'H':
		msg = &f.copyOutResponse
	case 'I':
		msg = &f.emptyQueryResponse
	case 'K':
		msg = &f.backendKeyData
	case 'n':
		msg = &f.noData
	case 'N':
		msg = &f.noticeResponse
	case 'R':
		// Authentication messages share one type byte; the sub-type is in
		// the body.
		msg, err = f.findAuthenticationMessageType(msgBody)
		if err != nil {
			return nil, err
		}
	case 's':
		msg = &f.portalSuspended
	case 'S':
		msg = &f.parameterStatus
	case 't':
		msg = &f.parameterDescription
	case 'T':
		msg = &f.rowDescription
	case 'V':
		msg = &f.functionCallResponse
	case 'W':
		msg = &f.copyBothResponse
	case 'Z':
		msg = &f.readyForQuery
	default:
		return nil, fmt.Errorf("unknown message type: %c", f.msgType)
	}

	err = msg.Decode(msgBody)
	if err != nil {
		// Do not hand out a partially decoded flyweight alongside the error.
		return nil, err
	}
	return msg, nil
}
|
|
|
|
// Authentication request codes, carried in the first four bytes (big-endian)
// of an Authentication ('R') message body. The values follow the AUTH_REQ_*
// definitions in the PostgreSQL server sources (historically in
// src/include/libpq/pqcomm.h), which list every code.
const (
	AuthTypeOk                = 0
	AuthTypeCleartextPassword = 3
	AuthTypeMD5Password       = 5
	AuthTypeSCMCreds          = 6 // rejected by Frontend as unimplemented
	AuthTypeGSS               = 7
	AuthTypeGSSCont           = 8
	AuthTypeSSPI              = 9 // rejected by Frontend as unimplemented
	AuthTypeSASL              = 10
	AuthTypeSASLContinue      = 11
	AuthTypeSASLFinal         = 12
)
|
|
|
|
func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) {
|
|
if len(src) < 4 {
|
|
return nil, errors.New("authentication message too short")
|
|
}
|
|
f.authType = binary.BigEndian.Uint32(src[:4])
|
|
|
|
switch f.authType {
|
|
case AuthTypeOk:
|
|
return &f.authenticationOk, nil
|
|
case AuthTypeCleartextPassword:
|
|
return &f.authenticationCleartextPassword, nil
|
|
case AuthTypeMD5Password:
|
|
return &f.authenticationMD5Password, nil
|
|
case AuthTypeSCMCreds:
|
|
return nil, errors.New("AuthTypeSCMCreds is unimplemented")
|
|
case AuthTypeGSS:
|
|
return &f.authenticationGSS, nil
|
|
case AuthTypeGSSCont:
|
|
return &f.authenticationGSSContinue, nil
|
|
case AuthTypeSSPI:
|
|
return nil, errors.New("AuthTypeSSPI is unimplemented")
|
|
case AuthTypeSASL:
|
|
return &f.authenticationSASL, nil
|
|
case AuthTypeSASLContinue:
|
|
return &f.authenticationSASLContinue, nil
|
|
case AuthTypeSASLFinal:
|
|
return &f.authenticationSASLFinal, nil
|
|
default:
|
|
return nil, fmt.Errorf("unknown authentication type: %d", f.authType)
|
|
}
|
|
}
|
|
|
|
// GetAuthType returns the authType used in the current state of the frontend.
|
|
// See SetAuthType for more information.
|
|
func (f *Frontend) GetAuthType() uint32 {
|
|
return f.authType
|
|
}
|