diff --git a/conn.go b/conn.go index 22e5f455..87bdff5d 100644 --- a/conn.go +++ b/conn.go @@ -593,7 +593,10 @@ type RowReader struct{} // TODO - Read*... func (rr *RowReader) ReadInt32(qr *QueryResult) int32 { - fd, size := qr.NextColumn() + fd, size, ok := qr.NextColumn() + if !ok { + return 0 + } // TODO - do something about nulls if size == -1 { @@ -604,7 +607,10 @@ func (rr *RowReader) ReadInt32(qr *QueryResult) int32 { } func (rr *RowReader) ReadInt64(qr *QueryResult) int64 { - fd, size := qr.NextColumn() + fd, size, ok := qr.NextColumn() + if !ok { + return 0 + } // TODO - do something about nulls if size == -1 { @@ -615,7 +621,12 @@ func (rr *RowReader) ReadInt64(qr *QueryResult) int64 { } func (rr *RowReader) ReadTime(qr *QueryResult) time.Time { - fd, size := qr.NextColumn() + var zeroTime time.Time + + fd, size, ok := qr.NextColumn() + if !ok { + return zeroTime + } // TODO - do something about nulls if size == -1 { @@ -626,7 +637,12 @@ func (rr *RowReader) ReadTime(qr *QueryResult) time.Time { } func (rr *RowReader) ReadDate(qr *QueryResult) time.Time { - fd, size := qr.NextColumn() + var zeroTime time.Time + + fd, size, ok := qr.NextColumn() + if !ok { + return zeroTime + } // TODO - do something about nulls if size == -1 { @@ -637,12 +653,19 @@ func (rr *RowReader) ReadDate(qr *QueryResult) time.Time { } func (rr *RowReader) ReadString(qr *QueryResult) string { - _, size := qr.NextColumn() + _, size, ok := qr.NextColumn() + if !ok { + return "" + } + return qr.mr.ReadString(size) } func (rr *RowReader) ReadValue(qr *QueryResult) interface{} { - fd, size := qr.NextColumn() + fd, size, ok := qr.NextColumn() + if !ok { + return nil + } if size > -1 { if vt, present := ValueTranscoders[fd.DataType]; present && vt.Decode != nil { @@ -768,12 +791,20 @@ func (qr *QueryResult) NextRow() bool { } } -func (qr *QueryResult) NextColumn() (*FieldDescription, int32) { +func (qr *QueryResult) NextColumn() (*FieldDescription, int32, bool) { + if qr.closed { + return nil, 0, false + } + if len(qr.fields) <= qr.columnIdx { + qr.Fatal(ProtocolError("No next column available")) + return nil, 0, false + } + fd := &qr.fields[qr.columnIdx] qr.columnIdx++ size := qr.mr.ReadInt32() - return fd, size + return fd, size, true } // TODO - document diff --git a/conn_test.go b/conn_test.go index e999214c..266a8bbe 100644 --- a/conn_test.go +++ b/conn_test.go @@ -309,6 +309,159 @@ func TestConnQuery(t *testing.T) { } } +// Do a simple query to ensure the connection is still usable +func ensureConnValid(t *testing.T, conn *pgx.Conn) { + var sum, rowCount int32 + + qr, err := conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + defer qr.Close() + + for qr.NextRow() { + var rr pgx.RowReader + sum += rr.ReadInt32(qr) + rowCount++ + } + + if qr.Err() != nil { + t.Fatalf("conn.Query failed: ", err) + } + + if rowCount != 10 { + t.Error("Select called onDataRow wrong number of times") + } + if sum != 55 { + t.Error("Wrong values returned") + } +} + +// Test that a connection stays valid when query results are closed early +func TestConnQueryCloseEarly(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + // Immediately close query without reading any rows + qr, err := conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + qr.Close() + + ensureConnValid(t, conn) + + // Read partial response then close + qr, err = conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + + ok := qr.NextRow() + if !ok { + t.Fatal("qr.NextRow terminated early") + } + + var rr pgx.RowReader + if n := rr.ReadInt32(qr); n != 1 { + t.Fatalf("Expected 1 from first row, but got %v", n) + } + + qr.Close() + + ensureConnValid(t, conn) +} + +// Test that a connection stays valid when query results read incorrectly +func TestConnQueryReadWrongTypeError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + // Read a single value incorrectly + qr, err := conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + + rowsRead := 0 + + for qr.NextRow() { + var rr pgx.RowReader + rr.ReadDate(qr) + rowsRead++ + } + + if rowsRead != 1 { + t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead) + } + + if qr.Err() == nil { + t.Fatal("Expected QueryResult to have an error after an improper read but it didn't") + } + + // Read too many values + qr, err = conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + + rowsRead = 0 + + for qr.NextRow() { + var rr pgx.RowReader + rr.ReadInt32(qr) + rr.ReadInt32(qr) + rowsRead++ + } + + if rowsRead != 1 { + t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead) + } + + if qr.Err() == nil { + t.Fatal("Expected QueryResult to have an error after an improper read but it didn't") + } + + ensureConnValid(t, conn) +} + +// Test that a connection stays valid when query results read incorrectly +func TestConnQueryReadTooManyValues(t *testing.T) { + // t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + // Read too many values + qr, err := conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + + rowsRead := 0 + + for qr.NextRow() { + var rr pgx.RowReader + rr.ReadInt32(qr) + rr.ReadInt32(qr) + rowsRead++ + } + + if rowsRead != 1 { + t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead) + } + + if qr.Err() == nil { + t.Fatal("Expected QueryResult to have an error after an improper read but it didn't") + } + + ensureConnValid(t, conn) +} + func TestConnectionSelectValue(t *testing.T) { t.Parallel()