diff --git a/chunkreader/chunkreader.go b/chunkreader/chunkreader.go index f9d6555c..f8d437b2 100644 --- a/chunkreader/chunkreader.go +++ b/chunkreader/chunkreader.go @@ -9,14 +9,12 @@ type ChunkReader struct { buf []byte rp, wp int // buf read position and write position - taken bool options Options } type Options struct { MinBufLen int // Minimum buffer length - BlockLen int // Increments to expand buffer (e.g. a 8000 byte request with a BlockLen of 1024 would yield a buffer len of 8192) } func NewChunkReader(r io.Reader) *ChunkReader { @@ -32,9 +30,6 @@ func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) { if options.MinBufLen == 0 { options.MinBufLen = 4096 } - if options.BlockLen == 0 { - options.BlockLen = 512 - } return &ChunkReader{ r: r, @@ -43,8 +38,8 @@ func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) { }, nil } -// Next returns buf filled with the next n bytes. buf is only valid until the -// next call to Next. If an error occurs, buf will be nil. +// Next returns buf filled with the next n bytes. If an error occurs, buf will +// be nil. func (r *ChunkReader) Next(n int) (buf []byte, err error) { // n bytes already in buf if (r.wp - r.rp) >= n { @@ -56,17 +51,12 @@ func (r *ChunkReader) Next(n int) (buf []byte, err error) { // available space in buf is less than n if len(r.buf) < n { r.copyBufContents(r.newBuf(n)) - r.taken = false } // buf is large enough, but need to shift filled area to start to make enough contiguous space minReadCount := n - (r.wp - r.rp) if (len(r.buf) - r.wp) < minReadCount { - newBuf := r.buf - if r.taken { - newBuf = r.newBuf(n) - r.taken = false - } + newBuf := r.newBuf(n) r.copyBufContents(newBuf) } @@ -79,20 +69,13 @@ func (r *ChunkReader) Next(n int) (buf []byte, err error) { return buf, nil } -// KeepLast prevents the last data retrieved by Next from being reused by the -// ChunkReader. -func (r *ChunkReader) KeepLast() { - r.taken = true -} - func (r *ChunkReader) appendAtLeast(fillLen int) error { n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen) r.wp += n return err } -func (r *ChunkReader) newBuf(min int) []byte { - size := ((min / r.options.BlockLen) + 1) * r.options.BlockLen +func (r *ChunkReader) newBuf(size int) []byte { if size < r.options.MinBufLen { size = r.options.MinBufLen } diff --git a/chunkreader/chunkreader_test.go b/chunkreader/chunkreader_test.go index 9c19ff4a..3be07e3c 100644 --- a/chunkreader/chunkreader_test.go +++ b/chunkreader/chunkreader_test.go @@ -7,7 +7,7 @@ import ( func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 2}) + r, err := NewChunkReaderEx(server, Options{MinBufLen: 4}) if err != nil { t.Fatal(err) } @@ -44,7 +44,7 @@ func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) { server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 2}) + r, err := NewChunkReaderEx(server, Options{MinBufLen: 4}) if err != nil { t.Fatal(err) } @@ -59,14 +59,14 @@ func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) { if bytes.Compare(n1, src[0:5]) != 0 { t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:5], n1) } - if len(r.buf) != 6 { - t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 6, len(r.buf)) + if len(r.buf) != 5 { + t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 5, len(r.buf)) } } -func TestChunkReaderNextReusesBuf(t *testing.T) { +func TestChunkReaderDoesNotReuseBuf(t *testing.T) { server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 1}) + r, err := NewChunkReaderEx(server, Options{MinBufLen: 4}) if err != nil { t.Fatal(err) } @@ -90,38 +90,6 @@ func TestChunkReaderNextReusesBuf(t *testing.T) { t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2) } - if bytes.Compare(n1, src[4:8]) != 0 { - t.Fatalf("Expected Next to have reused buf, %v found instead of %v", src[4:8], n1) - } -} - -func TestChunkReaderKeepLastPreventsBufReuse(t *testing.T) { - server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 1}) - if err != nil { - t.Fatal(err) - } - - src := []byte{1, 2, 3, 4, 5, 6, 7, 8} - server.Write(src) - - n1, err := r.Next(4) - if err != nil { - t.Fatal(err) - } - if bytes.Compare(n1, src[0:4]) != 0 { - t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:4], n1) - } - r.KeepLast() - - n2, err := r.Next(4) - if err != nil { - t.Fatal(err) - } - if bytes.Compare(n2, src[4:8]) != 0 { - t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2) - } - if bytes.Compare(n1, src[0:4]) != 0 { t.Fatalf("Expected KeepLast to prevent Next from overwriting buf, expected %v but it was %v", src[0:4], n1) } diff --git a/pgproto3/authentication.go b/pgproto3/authentication.go index e265a247..54f4978f 100644 --- a/pgproto3/authentication.go +++ b/pgproto3/authentication.go @@ -21,7 +21,7 @@ type Authentication struct { func (*Authentication) Backend() {} -func (dst *Authentication) UnmarshalBinary(src []byte) error { +func (dst *Authentication) Decode(src []byte) error { *dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])} switch dst.Type { diff --git a/pgproto3/backend_key_data.go b/pgproto3/backend_key_data.go index 5d8eb496..04f31aec 100644 --- a/pgproto3/backend_key_data.go +++ b/pgproto3/backend_key_data.go @@ -13,7 +13,7 @@ type BackendKeyData struct { func (*BackendKeyData) Backend() {} -func (dst *BackendKeyData) UnmarshalBinary(src []byte) error { +func (dst *BackendKeyData) Decode(src []byte) error { if len(src) != 8 { return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)} } diff --git a/pgproto3/bind_complete.go b/pgproto3/bind_complete.go index 756a30e6..4f1c44b8 100644 --- a/pgproto3/bind_complete.go +++ b/pgproto3/bind_complete.go @@ -8,7 +8,7 @@ type BindComplete struct{} func (*BindComplete) Backend() {} -func (dst *BindComplete) UnmarshalBinary(src []byte) error { +func (dst *BindComplete) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)} } diff --git a/pgproto3/close_complete.go b/pgproto3/close_complete.go index fd6ff180..9bab3e8c 100644 --- a/pgproto3/close_complete.go +++ b/pgproto3/close_complete.go @@ -8,7 +8,7 @@ type CloseComplete struct{} func (*CloseComplete) Backend() {} -func (dst *CloseComplete) UnmarshalBinary(src []byte) error { +func (dst *CloseComplete) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)} } diff --git a/pgproto3/command_complete.go b/pgproto3/command_complete.go index ac60153e..86653804 100644 --- a/pgproto3/command_complete.go +++ b/pgproto3/command_complete.go @@ -11,14 +11,13 @@ type CommandComplete struct { func (*CommandComplete) Backend() {} -func (dst *CommandComplete) UnmarshalBinary(src []byte) error { - buf := bytes.NewBuffer(src) - - b, err := buf.ReadBytes(0) - if err != nil { - return err +func (dst *CommandComplete) Decode(src []byte) error { + idx := bytes.IndexByte(src, 0) + if idx != len(src)-1 { + return &invalidMessageFormatErr{messageType: "CommandComplete"} } - dst.CommandTag = string(b[:len(b)-1]) + + dst.CommandTag = string(src[:idx]) return nil } diff --git a/pgproto3/copy_both_response.go b/pgproto3/copy_both_response.go index 2a4c58af..3857c187 100644 --- a/pgproto3/copy_both_response.go +++ b/pgproto3/copy_both_response.go @@ -13,7 +13,7 @@ type CopyBothResponse struct { func (*CopyBothResponse) Backend() {} -func (dst *CopyBothResponse) UnmarshalBinary(src []byte) error { +func (dst *CopyBothResponse) Decode(src []byte) error { buf := bytes.NewBuffer(src) if buf.Len() < 3 { diff --git a/pgproto3/copy_data.go b/pgproto3/copy_data.go index b9ea6272..de7ab4ff 100644 --- a/pgproto3/copy_data.go +++ b/pgproto3/copy_data.go @@ -13,9 +13,8 @@ type CopyData struct { func (*CopyData) Backend() {} func (*CopyData) Frontend() {} -func (dst *CopyData) UnmarshalBinary(src []byte) error { - dst.Data = make([]byte, len(src)) - copy(dst.Data, src) +func (dst *CopyData) Decode(src []byte) error { + dst.Data = src return nil } diff --git a/pgproto3/copy_in_response.go b/pgproto3/copy_in_response.go index 63868c7a..9854d665 100644 --- a/pgproto3/copy_in_response.go +++ b/pgproto3/copy_in_response.go @@ -13,7 +13,7 @@ type CopyInResponse struct { func (*CopyInResponse) Backend() {} -func (dst *CopyInResponse) UnmarshalBinary(src []byte) error { +func (dst *CopyInResponse) Decode(src []byte) error { buf := bytes.NewBuffer(src) if buf.Len() < 3 { diff --git a/pgproto3/copy_out_response.go b/pgproto3/copy_out_response.go index e46d9e8f..5ef6e4c1 100644 --- a/pgproto3/copy_out_response.go +++ b/pgproto3/copy_out_response.go @@ -13,7 +13,7 @@ type CopyOutResponse struct { func (*CopyOutResponse) Backend() {} -func (dst *CopyOutResponse) UnmarshalBinary(src []byte) error { +func (dst *CopyOutResponse) Decode(src []byte) error { buf := bytes.NewBuffer(src) if buf.Len() < 3 { diff --git a/pgproto3/data_row.go b/pgproto3/data_row.go index c95861b9..6b27f728 100644 --- a/pgproto3/data_row.go +++ b/pgproto3/data_row.go @@ -13,35 +13,42 @@ type DataRow struct { func (*DataRow) Backend() {} -func (dst *DataRow) UnmarshalBinary(src []byte) error { - buf := bytes.NewBuffer(src) - - if buf.Len() < 2 { +func (dst *DataRow) Decode(src []byte) error { + if len(src) < 2 { return &invalidMessageFormatErr{messageType: "DataRow"} } - fieldCount := int(binary.BigEndian.Uint16(buf.Next(2))) + rp := 0 + fieldCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 - dst.Values = make([][]byte, fieldCount) + // If the capacity of the values slice is too small OR substantially too + // large reallocate. This is too avoid one row with many columns from + // permanently allocating memory. + if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 { + dst.Values = make([][]byte, fieldCount, 32) + } else { + dst.Values = dst.Values[:fieldCount] + } for i := 0; i < fieldCount; i++ { - if buf.Len() < 4 { + if len(src[rp:]) < 4 { return &invalidMessageFormatErr{messageType: "DataRow"} } - msgSize := int(int32(binary.BigEndian.Uint32(buf.Next(4)))) + msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 // null if msgSize == -1 { - continue - } + dst.Values[i] = nil + } else { + if len(src[rp:]) < msgSize { + return &invalidMessageFormatErr{messageType: "DataRow"} + } - value := make([]byte, msgSize) - _, err := buf.Read(value) - if err != nil { - return err + dst.Values[i] = src[rp : rp+msgSize] + rp += msgSize } - - dst.Values[i] = value } return nil diff --git a/pgproto3/empty_query_response.go b/pgproto3/empty_query_response.go index de6e6272..13ed1886 100644 --- a/pgproto3/empty_query_response.go +++ b/pgproto3/empty_query_response.go @@ -8,7 +8,7 @@ type EmptyQueryResponse struct{} func (*EmptyQueryResponse) Backend() {} -func (dst *EmptyQueryResponse) UnmarshalBinary(src []byte) error { +func (dst *EmptyQueryResponse) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)} } diff --git a/pgproto3/error_response.go b/pgproto3/error_response.go index 82e408d7..602dd2a1 100644 --- a/pgproto3/error_response.go +++ b/pgproto3/error_response.go @@ -30,7 +30,7 @@ type ErrorResponse struct { func (*ErrorResponse) Backend() {} -func (dst *ErrorResponse) UnmarshalBinary(src []byte) error { +func (dst *ErrorResponse) Decode(src []byte) error { *dst = ErrorResponse{} buf := bytes.NewBuffer(src) diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index df67b718..50835836 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -108,6 +108,6 @@ func (b *Frontend) Receive() (BackendMessage, error) { return nil, err } - err = msg.UnmarshalBinary(msgBody) + err = msg.Decode(msgBody) return msg, err } diff --git a/pgproto3/function_call_response.go b/pgproto3/function_call_response.go index 5c692b36..1e0f16af 100644 --- a/pgproto3/function_call_response.go +++ b/pgproto3/function_call_response.go @@ -13,20 +13,24 @@ type FunctionCallResponse struct { func (*FunctionCallResponse) Backend() {} -func (dst *FunctionCallResponse) UnmarshalBinary(src []byte) error { - buf := bytes.NewBuffer(src) - - if buf.Len() < 4 { +func (dst *FunctionCallResponse) Decode(src []byte) error { + if len(src) < 4 { return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} } - resultSize := int(binary.BigEndian.Uint32(buf.Next(4))) - if buf.Len() != resultSize { + rp := 0 + resultSize := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + if resultSize == -1 { + dst.Result = nil + return nil + } + + if len(src[rp:]) != resultSize { return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} } - dst.Result = make([]byte, resultSize) - copy(dst.Result, buf.Bytes()) - + dst.Result = src[rp:] return nil } diff --git a/pgproto3/no_data.go b/pgproto3/no_data.go index 47ebf28e..3adec4ad 100644 --- a/pgproto3/no_data.go +++ b/pgproto3/no_data.go @@ -8,7 +8,7 @@ type NoData struct{} func (*NoData) Backend() {} -func (dst *NoData) UnmarshalBinary(src []byte) error { +func (dst *NoData) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)} } diff --git a/pgproto3/notice_response.go b/pgproto3/notice_response.go index 767c9a67..8af55baf 100644 --- a/pgproto3/notice_response.go +++ b/pgproto3/notice_response.go @@ -4,8 +4,8 @@ type NoticeResponse ErrorResponse func (*NoticeResponse) Backend() {} -func (dst *NoticeResponse) UnmarshalBinary(src []byte) error { - return (*ErrorResponse)(dst).UnmarshalBinary(src) +func (dst *NoticeResponse) Decode(src []byte) error { + return (*ErrorResponse)(dst).Decode(src) } func (src *NoticeResponse) MarshalBinary() ([]byte, error) { diff --git a/pgproto3/notification_response.go b/pgproto3/notification_response.go index 4ae8bab3..7262844e 100644 --- a/pgproto3/notification_response.go +++ b/pgproto3/notification_response.go @@ -14,7 +14,7 @@ type NotificationResponse struct { func (*NotificationResponse) Backend() {} -func (dst *NotificationResponse) UnmarshalBinary(src []byte) error { +func (dst *NotificationResponse) Decode(src []byte) error { buf := bytes.NewBuffer(src) pid := binary.BigEndian.Uint32(buf.Next(4)) diff --git a/pgproto3/parameter_description.go b/pgproto3/parameter_description.go index 40d92c50..32b6e1c1 100644 --- a/pgproto3/parameter_description.go +++ b/pgproto3/parameter_description.go @@ -12,7 +12,7 @@ type ParameterDescription struct { func (*ParameterDescription) Backend() {} -func (dst *ParameterDescription) UnmarshalBinary(src []byte) error { +func (dst *ParameterDescription) Decode(src []byte) error { buf := bytes.NewBuffer(src) if buf.Len() < 2 { diff --git a/pgproto3/parameter_status.go b/pgproto3/parameter_status.go index b8ce7f8d..9b10824c 100644 --- a/pgproto3/parameter_status.go +++ b/pgproto3/parameter_status.go @@ -13,7 +13,7 @@ type ParameterStatus struct { func (*ParameterStatus) Backend() {} -func (dst *ParameterStatus) UnmarshalBinary(src []byte) error { +func (dst *ParameterStatus) Decode(src []byte) error { buf := bytes.NewBuffer(src) b, err := buf.ReadBytes(0) diff --git a/pgproto3/parse_complete.go b/pgproto3/parse_complete.go index 24951e3d..e949c14c 100644 --- a/pgproto3/parse_complete.go +++ b/pgproto3/parse_complete.go @@ -8,7 +8,7 @@ type ParseComplete struct{} func (*ParseComplete) Backend() {} -func (dst *ParseComplete) UnmarshalBinary(src []byte) error { +func (dst *ParseComplete) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)} } diff --git a/pgproto3/pgproto3.go b/pgproto3/pgproto3.go index a9221239..3fe8fc93 100644 --- a/pgproto3/pgproto3.go +++ b/pgproto3/pgproto3.go @@ -2,8 +2,13 @@ package pgproto3 import "fmt" +// Message is the interface implemented by an object that can decode and encode +// a particular PostgreSQL message. +// +// Decode is allowed and expected to retain a reference to data after +// returning (unlike encoding.BinaryUnmarshaler). type Message interface { - UnmarshalBinary(data []byte) error + Decode(data []byte) error MarshalBinary() (data []byte, err error) } @@ -17,58 +22,6 @@ type BackendMessage interface { Backend() // no-op method to distinguish frontend from backend methods } -// func ParseBackend(typeByte byte, body []byte) (BackendMessage, error) { -// switch typeByte { -// case '1': -// return ParseParseComplete(body) -// case '2': -// return ParseBindComplete(body) -// case 'C': -// return ParseCommandComplete(body) -// case 'D': -// return ParseDataRow(body) -// case 'E': -// return ParseErrorResponse(body) -// case 'K': -// return ParseBackendKeyData(body) -// case 'R': -// return ParseAuthentication(body) -// case 'S': -// return ParseParameterStatus(body) -// case 'T': -// return ParseRowDescription(body) -// case 't': -// return ParseParameterDescription(body) -// case 'Z': -// return ParseReadyForQuery(body) -// default: -// return ParseUnknownMessage(typeByte, body) -// } -// } - -// func ParseFrontend(typeByte byte, body []byte) (FrontendMessage, error) { -// switch typeByte { -// case 'B': -// return ParseBind(body) -// case 'D': -// return ParseDescribe(body) -// case 'E': -// return ParseExecute(body) -// case 'P': -// return ParseParse(body) -// case 'p': -// return ParsePasswordMessage(body) -// case 'Q': -// return ParseQuery(body) -// case 'S': -// return ParseSync(body) -// case 'X': -// return ParseTerminate(body) -// default: -// return ParseUnknownMessage(typeByte, body) -// } -// } - type invalidMessageLenErr struct { messageType string expectedLen int diff --git a/pgproto3/query.go b/pgproto3/query.go index a3fc32eb..b5fc2dbc 100644 --- a/pgproto3/query.go +++ b/pgproto3/query.go @@ -11,7 +11,7 @@ type Query struct { func (*Query) Frontend() {} -func (dst *Query) UnmarshalBinary(src []byte) error { +func (dst *Query) Decode(src []byte) error { i := bytes.IndexByte(src, 0) if i != len(src)-1 { return &invalidMessageFormatErr{messageType: "Query"} diff --git a/pgproto3/ready_for_query.go b/pgproto3/ready_for_query.go index 09005d00..e0e4707a 100644 --- a/pgproto3/ready_for_query.go +++ b/pgproto3/ready_for_query.go @@ -10,7 +10,7 @@ type ReadyForQuery struct { func (*ReadyForQuery) Backend() {} -func (dst *ReadyForQuery) UnmarshalBinary(src []byte) error { +func (dst *ReadyForQuery) Decode(src []byte) error { if len(src) != 1 { return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)} } diff --git a/pgproto3/row_description.go b/pgproto3/row_description.go index 294a6aa9..b1110290 100644 --- a/pgproto3/row_description.go +++ b/pgproto3/row_description.go @@ -27,7 +27,7 @@ type RowDescription struct { func (*RowDescription) Backend() {} -func (dst *RowDescription) UnmarshalBinary(src []byte) error { +func (dst *RowDescription) Decode(src []byte) error { buf := bytes.NewBuffer(src) if buf.Len() < 2 {