2
0

Make Chunkreader an internal implementation detail

This commit is contained in:
Jack Christensen
2022-02-26 08:50:46 -06:00
parent d13f651810
commit 2e0ec225de
14 changed files with 124 additions and 203 deletions
+7 -7
View File
@@ -8,7 +8,7 @@ import (
// Backend acts as a server for the PostgreSQL wire protocol version 3.
type Backend struct {
cr ChunkReader
cr *chunkReader
w io.Writer
// Frontend message flyweights
@@ -30,11 +30,10 @@ type Backend struct {
sync Sync
terminate Terminate
bodyLen int
msgType byte
partialMsg bool
authType uint32
bodyLen int
msgType byte
partialMsg bool
authType uint32
}
const (
@@ -43,7 +42,8 @@ const (
)
// NewBackend creates a new Backend.
func NewBackend(cr ChunkReader, w io.Writer) *Backend {
func NewBackend(r io.Reader, w io.Writer) *Backend {
cr := newChunkReader(r, 0)
return &Backend{cr: cr, w: w}
}
+4 -4
View File
@@ -16,7 +16,7 @@ func TestBackendReceiveInterrupted(t *testing.T) {
server := &interruptReader{}
server.push([]byte{'Q', 0, 0, 0, 6})
backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil)
backend := pgproto3.NewBackend(server, nil)
msg, err := backend.Receive()
if err == nil {
@@ -43,7 +43,7 @@ func TestBackendReceiveUnexpectedEOF(t *testing.T) {
server := &interruptReader{}
server.push([]byte{'Q', 0, 0, 0, 6})
backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil)
backend := pgproto3.NewBackend(server, nil)
// Receive regular msg
msg, err := backend.Receive()
@@ -77,7 +77,7 @@ func TestStartupMessage(t *testing.T) {
server := &interruptReader{}
server.push(dst)
backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil)
backend := pgproto3.NewBackend(server, nil)
msg, err := backend.ReceiveStartupMessage()
require.NoError(t, err)
@@ -110,7 +110,7 @@ func TestStartupMessage(t *testing.T) {
dst = pgio.AppendUint32(dst, pgproto3.ProtocolVersionNumber)
server.push(dst)
backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil)
backend := pgproto3.NewBackend(server, nil)
msg, err := backend.ReceiveStartupMessage()
require.Error(t, err)
+80 -10
View File
@@ -2,18 +2,88 @@ package pgproto3
import (
"io"
"github.com/jackc/pgx/v5/chunkreader"
)
// ChunkReader is an interface to decouple github.com/jackc/chunkreader from this package.
type ChunkReader interface {
// Next returns buf filled with the next n bytes. If an error (including a partial read) occurs,
// buf must be nil. Next must preserve any partially read data. Next must not reuse buf.
Next(n int) (buf []byte, err error)
// chunkReader is a io.Reader wrapper that minimizes IO reads and memory allocations. It allocates memory in chunks and
// will read as much as will fit in the current buffer in a single call regardless of how large a read is actually
// requested. The memory returned via Next is owned by the caller. This avoids the need for an additional copy.
//
// The downside of this approach is that a large buffer can be pinned in memory even if only a small slice is
// referenced. For example, an entire 4096 byte block could be pinned in memory by even a 1 byte slice. In these rare
// cases it would be advantageous to copy the bytes to another slice.
type chunkReader struct {
r io.Reader
buf []byte
rp, wp int // buf read position and write position
minBufLen int
}
// NewChunkReader creates and returns a new default ChunkReader.
func NewChunkReader(r io.Reader) ChunkReader {
return chunkreader.New(r)
// newChunkReader creates and returns a new chunkReader for r with default configuration with minBufSize internal buffer.
// If bufSize is <= 0 it uses a default value.
func newChunkReader(r io.Reader, minBufSize int) *chunkReader {
if minBufSize <= 0 {
// By historical reasons Postgres currently has 8KB send buffer inside,
// so here we want to have at least the same size buffer.
// @see https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134
// @see https://www.postgresql.org/message-id/0cdc5485-cb3c-5e16-4a46-e3b2f7a41322%40ya.ru
minBufSize = 8192
}
return &chunkReader{
r: r,
buf: make([]byte, minBufSize),
minBufLen: minBufSize,
}
}
// Next returns buf filled with the next n bytes. The caller gains ownership of buf. It is not necessary to make a copy
// of buf. If an error occurs, buf will be nil.
func (r *chunkReader) Next(n int) (buf []byte, err error) {
// n bytes already in buf
if (r.wp - r.rp) >= n {
buf = r.buf[r.rp : r.rp+n]
r.rp += n
return buf, err
}
// available space in buf is less than n
if len(r.buf) < n {
r.copyBufContents(r.newBuf(n))
}
// buf is large enough, but need to shift filled area to start to make enough contiguous space
minReadCount := n - (r.wp - r.rp)
if (len(r.buf) - r.wp) < minReadCount {
newBuf := r.newBuf(n)
r.copyBufContents(newBuf)
}
if err := r.appendAtLeast(minReadCount); err != nil {
return nil, err
}
buf = r.buf[r.rp : r.rp+n]
r.rp += n
return buf, nil
}
func (r *chunkReader) appendAtLeast(fillLen int) error {
n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen)
r.wp += n
return err
}
func (r *chunkReader) newBuf(size int) []byte {
if size < r.minBufLen {
size = r.minBufLen
}
return make([]byte, size)
}
func (r *chunkReader) copyBufContents(dest []byte) {
r.wp = copy(dest, r.buf[r.rp:r.wp])
r.rp = 0
r.buf = dest
}
+116
View File
@@ -0,0 +1,116 @@
package pgproto3
import (
"bytes"
"math/rand"
"testing"
)
func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) {
server := &bytes.Buffer{}
r := newChunkReader(server, 4)
src := []byte{1, 2, 3, 4}
server.Write(src)
n1, err := r.Next(2)
if err != nil {
t.Fatal(err)
}
if bytes.Compare(n1, src[0:2]) != 0 {
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:2], n1)
}
n2, err := r.Next(2)
if err != nil {
t.Fatal(err)
}
if bytes.Compare(n2, src[2:4]) != 0 {
t.Fatalf("Expected read bytes to be %v, but they were %v", src[2:4], n2)
}
if bytes.Compare(r.buf, src) != 0 {
t.Fatalf("Expected r.buf to be %v, but it was %v", src, r.buf)
}
if r.rp != 4 {
t.Fatalf("Expected r.rp to be %v, but it was %v", 4, r.rp)
}
if r.wp != 4 {
t.Fatalf("Expected r.wp to be %v, but it was %v", 4, r.wp)
}
}
func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) {
server := &bytes.Buffer{}
r := newChunkReader(server, 4)
src := []byte{1, 2, 3, 4, 5, 6, 7, 8}
server.Write(src)
n1, err := r.Next(5)
if err != nil {
t.Fatal(err)
}
if bytes.Compare(n1, src[0:5]) != 0 {
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:5], n1)
}
if len(r.buf) != 5 {
t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 5, len(r.buf))
}
}
func TestChunkReaderDoesNotReuseBuf(t *testing.T) {
server := &bytes.Buffer{}
r := newChunkReader(server, 4)
src := []byte{1, 2, 3, 4, 5, 6, 7, 8}
server.Write(src)
n1, err := r.Next(4)
if err != nil {
t.Fatal(err)
}
if bytes.Compare(n1, src[0:4]) != 0 {
t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:4], n1)
}
n2, err := r.Next(4)
if err != nil {
t.Fatal(err)
}
if bytes.Compare(n2, src[4:8]) != 0 {
t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2)
}
if bytes.Compare(n1, src[0:4]) != 0 {
t.Fatalf("Expected KeepLast to prevent Next from overwriting buf, expected %v but it was %v", src[0:4], n1)
}
}
type randomReader struct {
rnd *rand.Rand
}
// Read reads a random number of random bytes.
func (r *randomReader) Read(p []byte) (n int, err error) {
n = r.rnd.Intn(len(p) + 1)
return r.rnd.Read(p[:n])
}
func TestChunkReaderNextFuzz(t *testing.T) {
rr := &randomReader{rnd: rand.New(rand.NewSource(1))}
r := newChunkReader(rr, 8192)
randomSizes := rand.New(rand.NewSource(0))
for i := 0; i < 100000; i++ {
size := randomSizes.Intn(16384) + 1
buf, err := r.Next(size)
if err != nil {
t.Fatal(err)
}
if len(buf) != size {
t.Fatalf("Expected to get %v bytes but got %v bytes", size, len(buf))
}
}
}
+1 -1
View File
@@ -14,7 +14,7 @@ type PgFortuneBackend struct {
}
func NewPgFortuneBackend(conn net.Conn, responder func() ([]byte, error)) *PgFortuneBackend {
backend := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)
backend := pgproto3.NewBackend(conn, conn)
connHandler := &PgFortuneBackend{
backend: backend,
+3 -2
View File
@@ -9,7 +9,7 @@ import (
// Frontend acts as a client for the PostgreSQL wire protocol version 3.
type Frontend struct {
cr ChunkReader
cr *chunkReader
w io.Writer
// Backend message flyweights
@@ -49,7 +49,8 @@ type Frontend struct {
}
// NewFrontend creates a new Frontend.
func NewFrontend(cr ChunkReader, w io.Writer) *Frontend {
func NewFrontend(r io.Reader, w io.Writer) *Frontend {
cr := newChunkReader(r, 0)
return &Frontend{cr: cr, w: w}
}
+3 -3
View File
@@ -38,7 +38,7 @@ func TestFrontendReceiveInterrupted(t *testing.T) {
server := &interruptReader{}
server.push([]byte{'Z', 0, 0, 0, 5})
frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(server), nil)
frontend := pgproto3.NewFrontend(server, nil)
msg, err := frontend.Receive()
if err == nil {
@@ -65,7 +65,7 @@ func TestFrontendReceiveUnexpectedEOF(t *testing.T) {
server := &interruptReader{}
server.push([]byte{'Z', 0, 0, 0, 5})
frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(server), nil)
frontend := pgproto3.NewFrontend(server, nil)
msg, err := frontend.Receive()
if err == nil {
@@ -109,7 +109,7 @@ func TestErrorResponse(t *testing.T) {
server := &interruptReader{}
server.push(raw)
frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(server), nil)
frontend := pgproto3.NewFrontend(server, nil)
got, err := frontend.Receive()
require.NoError(t, err)