More tests and bug fixes
This commit is contained in:
@@ -495,7 +495,23 @@ func (qr *QueryResult) Err() error {
|
|||||||
return qr.err
|
return qr.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// abort signals that the query was not successfully sent to the server.
|
||||||
|
// This differs from Fatal in that it is not necessary to readUntilReadyForQuery
|
||||||
|
func (qr *QueryResult) abort(err error) {
|
||||||
|
if qr.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
qr.err = err
|
||||||
|
qr.close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fatal signals an error occurred after the query was sent to the server
|
||||||
func (qr *QueryResult) Fatal(err error) {
|
func (qr *QueryResult) Fatal(err error) {
|
||||||
|
if qr.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
qr.err = err
|
qr.err = err
|
||||||
qr.Close()
|
qr.Close()
|
||||||
}
|
}
|
||||||
@@ -647,19 +663,18 @@ func (c *Conn) Query(sql string, args ...interface{}) (*QueryResult, error) {
|
|||||||
c.qr = QueryResult{conn: c}
|
c.qr = QueryResult{conn: c}
|
||||||
qr := &c.qr
|
qr := &c.qr
|
||||||
|
|
||||||
// TODO - shouldn't be messing with qr.err and qr.closed directly
|
|
||||||
if ps, present := c.preparedStatements[sql]; present {
|
if ps, present := c.preparedStatements[sql]; present {
|
||||||
qr.fields = ps.FieldDescriptions
|
qr.fields = ps.FieldDescriptions
|
||||||
qr.err = c.sendPreparedQuery(ps, args...)
|
err := c.sendPreparedQuery(ps, args...)
|
||||||
if qr.err != nil {
|
if err != nil {
|
||||||
qr.closed = true
|
qr.abort(err)
|
||||||
}
|
}
|
||||||
return qr, qr.err
|
return qr, qr.err
|
||||||
}
|
}
|
||||||
|
|
||||||
qr.err = c.sendSimpleQuery(sql, args...)
|
err := c.sendSimpleQuery(sql, args...)
|
||||||
if qr.err != nil {
|
if err != nil {
|
||||||
qr.closed = true
|
qr.abort(err)
|
||||||
return qr, qr.err
|
return qr, qr.err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -668,8 +683,7 @@ func (c *Conn) Query(sql string, args ...interface{}) (*QueryResult, error) {
|
|||||||
for {
|
for {
|
||||||
t, r, err := c.rxMsg()
|
t, r, err := c.rxMsg()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
qr.err = err
|
qr.Fatal(err)
|
||||||
qr.closed = true
|
|
||||||
return qr, qr.err
|
return qr, qr.err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -680,8 +694,7 @@ func (c *Conn) Query(sql string, args ...interface{}) (*QueryResult, error) {
|
|||||||
default:
|
default:
|
||||||
err = qr.conn.processContextFreeMsg(t, r)
|
err = qr.conn.processContextFreeMsg(t, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
qr.closed = true
|
qr.Fatal(err)
|
||||||
qr.err = err
|
|
||||||
return qr, qr.err
|
return qr, qr.err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+191
-5
@@ -1,6 +1,8 @@
|
|||||||
package pgx_test
|
package pgx_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
"github.com/jackc/pgx"
|
"github.com/jackc/pgx"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -891,15 +893,199 @@ func TestCommandTag(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryRowError(t *testing.T) {
|
func TestQueryRowCoreTypes(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnect(t, *defaultConnConfig)
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
defer closeConn(t, conn)
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
var n int32
|
type allTypes struct {
|
||||||
err := conn.QueryRow("SYNTAX ERROR").Scan(&n)
|
s string
|
||||||
if _, ok := err.(pgx.PgError); !ok {
|
i16 int16
|
||||||
t.Fatalf("Expected to receive PgError, but instead received: %v", err)
|
i32 int32
|
||||||
|
i64 int64
|
||||||
|
f32 float32
|
||||||
|
f64 float64
|
||||||
|
b bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var actual, zero allTypes
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
sql string
|
||||||
|
queryArgs []interface{}
|
||||||
|
scanArgs []interface{}
|
||||||
|
expected allTypes
|
||||||
|
}{
|
||||||
|
{"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.s}, allTypes{s: "Jack"}},
|
||||||
|
{"select $1::int2", []interface{}{int16(42)}, []interface{}{&actual.i16}, allTypes{i16: 42}},
|
||||||
|
{"select $1::int4", []interface{}{int32(42)}, []interface{}{&actual.i32}, allTypes{i32: 42}},
|
||||||
|
{"select $1::int8", []interface{}{int64(42)}, []interface{}{&actual.i64}, allTypes{i64: 42}},
|
||||||
|
{"select $1::float4", []interface{}{float32(1.23)}, []interface{}{&actual.f32}, allTypes{f32: 1.23}},
|
||||||
|
{"select $1::float8", []interface{}{float64(1.23)}, []interface{}{&actual.f64}, allTypes{f64: 1.23}},
|
||||||
|
{"select $1::bool", []interface{}{true}, []interface{}{&actual.b}, allTypes{b: true}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
psName := fmt.Sprintf("success%d", i)
|
||||||
|
mustPrepare(t, conn, psName, tt.sql)
|
||||||
|
|
||||||
|
for _, sql := range []string{tt.sql, psName} {
|
||||||
|
actual = zero
|
||||||
|
|
||||||
|
err := conn.QueryRow(sql, tt.queryArgs...).Scan(tt.scanArgs...)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, sql, tt.queryArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual != tt.expected {
|
||||||
|
t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, sql, tt.queryArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryRowCoreBytea(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
var actual []byte
|
||||||
|
sql := "select $1::bytea"
|
||||||
|
queryArg := []byte{0, 15, 255, 17}
|
||||||
|
expected := []byte{0, 15, 255, 17}
|
||||||
|
|
||||||
|
psName := "selectBytea"
|
||||||
|
mustPrepare(t, conn, psName, sql)
|
||||||
|
|
||||||
|
for _, sql := range []string{sql, psName} {
|
||||||
|
actual = nil
|
||||||
|
|
||||||
|
err := conn.QueryRow(sql, queryArg).Scan(&actual)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Compare(actual, expected) != 0 {
|
||||||
|
t.Errorf("Expected %v, got %v (sql -> %v)", expected, actual, sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryRowUnpreparedErrors(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
type allTypes struct {
|
||||||
|
s string
|
||||||
|
i16 int16
|
||||||
|
i32 int32
|
||||||
|
i64 int64
|
||||||
|
f32 float32
|
||||||
|
f64 float64
|
||||||
|
b bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var actual, zero allTypes
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
sql string
|
||||||
|
queryArgs []interface{}
|
||||||
|
scanArgs []interface{}
|
||||||
|
err string
|
||||||
|
}{
|
||||||
|
{"select $1", []interface{}{"Jack"}, []interface{}{&actual.i16}, "Expected type oid 21 but received type oid 705"},
|
||||||
|
{"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`},
|
||||||
|
{"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
actual = zero
|
||||||
|
|
||||||
|
err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("%d. Unexpected success (sql -> %v, queryArgs -> %v)", i, tt.sql, tt.queryArgs)
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), tt.err) {
|
||||||
|
t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, queryArgs -> %v)", i, tt.err, err, tt.sql, tt.queryArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryRowPreparedErrors(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
type allTypes struct {
|
||||||
|
s string
|
||||||
|
i16 int16
|
||||||
|
i32 int32
|
||||||
|
i64 int64
|
||||||
|
f32 float32
|
||||||
|
f64 float64
|
||||||
|
b bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var actual, zero allTypes
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
sql string
|
||||||
|
queryArgs []interface{}
|
||||||
|
scanArgs []interface{}
|
||||||
|
err string
|
||||||
|
}{
|
||||||
|
{"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "Expected type oid 21 but received type oid 25"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
psName := fmt.Sprintf("ps%d", i)
|
||||||
|
mustPrepare(t, conn, psName, tt.sql)
|
||||||
|
|
||||||
|
actual = zero
|
||||||
|
|
||||||
|
err := conn.QueryRow(psName, tt.queryArgs...).Scan(tt.scanArgs...)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("%d. Unexpected success (sql -> %v, queryArgs -> %v)", i, tt.sql, tt.queryArgs)
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), tt.err) {
|
||||||
|
t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, queryArgs -> %v)", i, tt.err, err, tt.sql, tt.queryArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryPreparedEncodeError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
mustPrepare(t, conn, "testTranscode", "select $1::integer")
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Deallocate("testTranscode"); err != nil {
|
||||||
|
t.Fatalf("Unable to deallocate prepared statement: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := conn.Query("testTranscode", "wrong")
|
||||||
|
switch {
|
||||||
|
case err == nil:
|
||||||
|
t.Error("Expected transcode error to return error, but it didn't")
|
||||||
|
case err.Error() == "Expected integer representable in int32, received string wrong":
|
||||||
|
// Correct behavior
|
||||||
|
default:
|
||||||
|
t.Errorf("Expected transcode error, received %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -209,7 +209,7 @@ func encodeBool(w *WriteBuf, value interface{}) error {
|
|||||||
|
|
||||||
func decodeInt8(qr *QueryResult, fd *FieldDescription, size int32) int64 {
|
func decodeInt8(qr *QueryResult, fd *FieldDescription, size int32) int64 {
|
||||||
if fd.DataType != Int8Oid {
|
if fd.DataType != Int8Oid {
|
||||||
qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read %v but received: %v", Int8Oid, fd.DataType)))
|
qr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int8Oid, fd.DataType)))
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -270,7 +270,7 @@ func encodeInt8(w *WriteBuf, value interface{}) error {
|
|||||||
|
|
||||||
func decodeInt2(qr *QueryResult, fd *FieldDescription, size int32) int16 {
|
func decodeInt2(qr *QueryResult, fd *FieldDescription, size int32) int16 {
|
||||||
if fd.DataType != Int2Oid {
|
if fd.DataType != Int2Oid {
|
||||||
qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read %v but received: %v", Int2Oid, fd.DataType)))
|
qr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int2Oid, fd.DataType)))
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -346,7 +346,7 @@ func encodeInt2(w *WriteBuf, value interface{}) error {
|
|||||||
|
|
||||||
func decodeInt4(qr *QueryResult, fd *FieldDescription, size int32) int32 {
|
func decodeInt4(qr *QueryResult, fd *FieldDescription, size int32) int32 {
|
||||||
if fd.DataType != Int4Oid {
|
if fd.DataType != Int4Oid {
|
||||||
qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read %v but received: %v", Int4Oid, fd.DataType)))
|
qr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int4Oid, fd.DataType)))
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -530,6 +530,8 @@ func decodeBytea(qr *QueryResult, fd *FieldDescription, size int32) []byte {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
|
case BinaryFormatCode:
|
||||||
|
return qr.mr.ReadBytes(size)
|
||||||
default:
|
default:
|
||||||
qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
|
qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
|
||||||
return nil
|
return nil
|
||||||
@@ -552,7 +554,7 @@ func decodeDate(qr *QueryResult, fd *FieldDescription, size int32) time.Time {
|
|||||||
var zeroTime time.Time
|
var zeroTime time.Time
|
||||||
|
|
||||||
if fd.DataType != DateOid {
|
if fd.DataType != DateOid {
|
||||||
qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read date but received: %v", fd.DataType)))
|
qr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", DateOid, fd.DataType)))
|
||||||
return zeroTime
|
return zeroTime
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -591,7 +593,7 @@ func decodeTimestampTz(qr *QueryResult, fd *FieldDescription, size int32) time.T
|
|||||||
var zeroTime time.Time
|
var zeroTime time.Time
|
||||||
|
|
||||||
if fd.DataType != TimestampTzOid {
|
if fd.DataType != TimestampTzOid {
|
||||||
qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read timestamptz but received: %v", fd.DataType)))
|
qr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampTzOid, fd.DataType)))
|
||||||
return zeroTime
|
return zeroTime
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -78,30 +78,6 @@ func TestSanitizeSql(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEncodeError(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
conn := mustConnect(t, *defaultConnConfig)
|
|
||||||
defer closeConn(t, conn)
|
|
||||||
|
|
||||||
mustPrepare(t, conn, "testTranscode", "select $1::integer")
|
|
||||||
defer func() {
|
|
||||||
if err := conn.Deallocate("testTranscode"); err != nil {
|
|
||||||
t.Fatalf("Unable to deallocate prepared statement: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
_, err := conn.Query("testTranscode", "wrong")
|
|
||||||
switch {
|
|
||||||
case err == nil:
|
|
||||||
t.Error("Expected transcode error to return error, but it didn't")
|
|
||||||
case err.Error() == "Expected integer representable in int32, received string wrong":
|
|
||||||
// Correct behavior
|
|
||||||
default:
|
|
||||||
t.Errorf("Expected transcode error, received %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO
|
// TODO
|
||||||
func TestNilTranscode(t *testing.T) {
|
func TestNilTranscode(t *testing.T) {
|
||||||
// t.Parallel()
|
// t.Parallel()
|
||||||
|
|||||||
Reference in New Issue
Block a user