diff --git a/conn.go b/conn.go index d7ecff0e..e33fba8d 100644 --- a/conn.go +++ b/conn.go @@ -495,7 +495,23 @@ func (qr *QueryResult) Err() error { return qr.err } +// abort signals that the query was not successfully sent to the server. +// This differs from Fatal in that it is not necessary to readUntilReadyForQuery +func (qr *QueryResult) abort(err error) { + if qr.err != nil { + return + } + + qr.err = err + qr.close() +} + +// Fatal signals an error occurred after the query was sent to the server func (qr *QueryResult) Fatal(err error) { + if qr.err != nil { + return + } + qr.err = err qr.Close() } @@ -647,19 +663,18 @@ func (c *Conn) Query(sql string, args ...interface{}) (*QueryResult, error) { c.qr = QueryResult{conn: c} qr := &c.qr - // TODO - shouldn't be messing with qr.err and qr.closed directly if ps, present := c.preparedStatements[sql]; present { qr.fields = ps.FieldDescriptions - qr.err = c.sendPreparedQuery(ps, args...) - if qr.err != nil { - qr.closed = true + err := c.sendPreparedQuery(ps, args...) + if err != nil { + qr.abort(err) } return qr, qr.err } - qr.err = c.sendSimpleQuery(sql, args...) - if qr.err != nil { - qr.closed = true + err := c.sendSimpleQuery(sql, args...) + if err != nil { + qr.abort(err) return qr, qr.err } @@ -668,8 +683,7 @@ func (c *Conn) Query(sql string, args ...interface{}) (*QueryResult, error) { for { t, r, err := c.rxMsg() if err != nil { - qr.err = err - qr.closed = true + qr.Fatal(err) return qr, qr.err } @@ -680,8 +694,7 @@ func (c *Conn) Query(sql string, args ...interface{}) (*QueryResult, error) { default: err = qr.conn.processContextFreeMsg(t, r) if err != nil { - qr.closed = true - qr.err = err + qr.Fatal(err) return qr, qr.err } } diff --git a/conn_test.go b/conn_test.go index 4a5e3a29..f94dfed8 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,6 +1,8 @@ package pgx_test import ( + "bytes" + "fmt" "github.com/jackc/pgx" "strings" "sync" @@ -891,15 +893,199 @@ func TestCommandTag(t *testing.T) { } } -func TestQueryRowError(t *testing.T) { +func TestQueryRowCoreTypes(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - var n int32 - err := conn.QueryRow("SYNTAX ERROR").Scan(&n) - if _, ok := err.(pgx.PgError); !ok { - t.Fatalf("Expected to receive PgError, but instead received: %v", err) + type allTypes struct { + s string + i16 int16 + i32 int32 + i64 int64 + f32 float32 + f64 float64 + b bool + } + + var actual, zero allTypes + + tests := []struct { + sql string + queryArgs []interface{} + scanArgs []interface{} + expected allTypes + }{ + {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.s}, allTypes{s: "Jack"}}, + {"select $1::int2", []interface{}{int16(42)}, []interface{}{&actual.i16}, allTypes{i16: 42}}, + {"select $1::int4", []interface{}{int32(42)}, []interface{}{&actual.i32}, allTypes{i32: 42}}, + {"select $1::int8", []interface{}{int64(42)}, []interface{}{&actual.i64}, allTypes{i64: 42}}, + {"select $1::float4", []interface{}{float32(1.23)}, []interface{}{&actual.f32}, allTypes{f32: 1.23}}, + {"select $1::float8", []interface{}{float64(1.23)}, []interface{}{&actual.f64}, allTypes{f64: 1.23}}, + {"select $1::bool", []interface{}{true}, []interface{}{&actual.b}, allTypes{b: true}}, + } + + for i, tt := range tests { + psName := fmt.Sprintf("success%d", i) + mustPrepare(t, conn, psName, tt.sql) + + for _, sql := range []string{tt.sql, psName} { + actual = zero + + err := conn.QueryRow(sql, tt.queryArgs...).Scan(tt.scanArgs...) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, sql, tt.queryArgs) + } + + if actual != tt.expected { + t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, sql, tt.queryArgs) + } + + ensureConnValid(t, conn) + } + } +} + +func TestQueryRowCoreBytea(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var actual []byte + sql := "select $1::bytea" + queryArg := []byte{0, 15, 255, 17} + expected := []byte{0, 15, 255, 17} + + psName := "selectBytea" + mustPrepare(t, conn, psName, sql) + + for _, sql := range []string{sql, psName} { + actual = nil + + err := conn.QueryRow(sql, queryArg).Scan(&actual) + if err != nil { + t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) + } + + if bytes.Compare(actual, expected) != 0 { + t.Errorf("Expected %v, got %v (sql -> %v)", expected, actual, sql) + } + + ensureConnValid(t, conn) + } +} + +func TestQueryRowUnpreparedErrors(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + type allTypes struct { + s string + i16 int16 + i32 int32 + i64 int64 + f32 float32 + f64 float64 + b bool + } + + var actual, zero allTypes + + tests := []struct { + sql string + queryArgs []interface{} + scanArgs []interface{} + err string + }{ + {"select $1", []interface{}{"Jack"}, []interface{}{&actual.i16}, "Expected type oid 21 but received type oid 705"}, + {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`}, + {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"}, + } + + for i, tt := range tests { + actual = zero + + err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) + if err == nil { + t.Errorf("%d. Unexpected success (sql -> %v, queryArgs -> %v)", i, tt.sql, tt.queryArgs) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, queryArgs -> %v)", i, tt.err, err, tt.sql, tt.queryArgs) + } + + ensureConnValid(t, conn) + } +} + +func TestQueryRowPreparedErrors(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + type allTypes struct { + s string + i16 int16 + i32 int32 + i64 int64 + f32 float32 + f64 float64 + b bool + } + + var actual, zero allTypes + + tests := []struct { + sql string + queryArgs []interface{} + scanArgs []interface{} + err string + }{ + {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "Expected type oid 21 but received type oid 25"}, + } + + for i, tt := range tests { + psName := fmt.Sprintf("ps%d", i) + mustPrepare(t, conn, psName, tt.sql) + + actual = zero + + err := conn.QueryRow(psName, tt.queryArgs...).Scan(tt.scanArgs...) + if err == nil { + t.Errorf("%d. Unexpected success (sql -> %v, queryArgs -> %v)", i, tt.sql, tt.queryArgs) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, queryArgs -> %v)", i, tt.err, err, tt.sql, tt.queryArgs) + } + + ensureConnValid(t, conn) + } +} + +func TestQueryPreparedEncodeError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustPrepare(t, conn, "testTranscode", "select $1::integer") + defer func() { + if err := conn.Deallocate("testTranscode"); err != nil { + t.Fatalf("Unable to deallocate prepared statement: %v", err) + } + }() + + _, err := conn.Query("testTranscode", "wrong") + switch { + case err == nil: + t.Error("Expected transcode error to return error, but it didn't") + case err.Error() == "Expected integer representable in int32, received string wrong": + // Correct behavior + default: + t.Errorf("Expected transcode error, received %v", err) } } diff --git a/values.go b/values.go index f413e3de..f350d69c 100644 --- a/values.go +++ b/values.go @@ -209,7 +209,7 @@ func encodeBool(w *WriteBuf, value interface{}) error { func decodeInt8(qr *QueryResult, fd *FieldDescription, size int32) int64 { if fd.DataType != Int8Oid { - qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read %v but received: %v", Int8Oid, fd.DataType))) + qr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int8Oid, fd.DataType))) return 0 } @@ -270,7 +270,7 @@ func encodeInt8(w *WriteBuf, value interface{}) error { func decodeInt2(qr *QueryResult, fd *FieldDescription, size int32) int16 { if fd.DataType != Int2Oid { - qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read %v but received: %v", Int2Oid, fd.DataType))) + qr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int2Oid, fd.DataType))) return 0 } @@ -346,7 +346,7 @@ func encodeInt2(w *WriteBuf, value interface{}) error { func decodeInt4(qr *QueryResult, fd *FieldDescription, size int32) int32 { if fd.DataType != Int4Oid { - qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read %v but received: %v", Int4Oid, fd.DataType))) + qr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int4Oid, fd.DataType))) return 0 } @@ -530,6 +530,8 @@ func decodeBytea(qr *QueryResult, fd *FieldDescription, size int32) []byte { return nil } return b + case BinaryFormatCode: + return qr.mr.ReadBytes(size) default: qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) return nil @@ -552,7 +554,7 @@ func decodeDate(qr *QueryResult, fd *FieldDescription, size int32) time.Time { var zeroTime time.Time if fd.DataType != DateOid { - qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read date but received: %v", fd.DataType))) + qr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", DateOid, fd.DataType))) return zeroTime } @@ -591,7 +593,7 @@ func decodeTimestampTz(qr *QueryResult, fd *FieldDescription, size int32) time.T var zeroTime time.Time if fd.DataType != TimestampTzOid { - qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read timestamptz but received: %v", fd.DataType))) + qr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampTzOid, fd.DataType))) return zeroTime } diff --git a/values_test.go b/values_test.go index e9a6a21e..2285ad5d 100644 --- a/values_test.go +++ b/values_test.go @@ -78,30 +78,6 @@ func TestSanitizeSql(t *testing.T) { } } -func TestEncodeError(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - mustPrepare(t, conn, "testTranscode", "select $1::integer") - defer func() { - if err := conn.Deallocate("testTranscode"); err != nil { - t.Fatalf("Unable to deallocate prepared statement: %v", err) - } - }() - - _, err := conn.Query("testTranscode", "wrong") - switch { - case err == nil: - t.Error("Expected transcode error to return error, but it didn't") - case err.Error() == "Expected integer representable in int32, received string wrong": - // Correct behavior - default: - t.Errorf("Expected transcode error, received %v", err) - } -} - // TODO func TestNilTranscode(t *testing.T) { // t.Parallel()