diff --git a/query.go b/query.go index 97af8719..a6f8fc34 100644 --- a/query.go +++ b/query.go @@ -214,7 +214,13 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { case *bool: *d = decodeBool(vr) case *[]byte: - *d = decodeBytea(vr) + // If it actually is a bytea then pass it through decodeBytea (so it can be decoded if it is in text format) + // Otherwise read the bytes directly regardless of what the actual type is. + if vr.Type().DataType == ByteaOid { + *d = decodeBytea(vr) + } else { + *d = vr.ReadBytes(vr.Len()) + } case *int64: *d = decodeInt8(vr) case *int16: diff --git a/query_test.go b/query_test.go index f43a66e1..956e38c2 100644 --- a/query_test.go +++ b/query_test.go @@ -383,29 +383,36 @@ func TestQueryRowCoreTypes(t *testing.T) { } } -func TestQueryRowCoreBytea(t *testing.T) { +func TestQueryRowCoreByteSlice(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - var actual []byte - sql := "select $1::bytea" - queryArg := []byte{0, 15, 255, 17} - expected := []byte{0, 15, 255, 17} - - actual = nil - - err := conn.QueryRow(sql, queryArg).Scan(&actual) - if err != nil { - t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) + tests := []struct { + sql string + queryArg interface{} + expected []byte + }{ + {"select $1::text", "Jack", []byte("Jack")}, + {"select $1::varchar", "Jack", []byte("Jack")}, + {"select $1::bytea", []byte{0, 15, 255, 17}, []byte{0, 15, 255, 17}}, } - if bytes.Compare(actual, expected) != 0 { - t.Errorf("Expected %v, got %v (sql -> %v)", expected, actual, sql) - } + for i, tt := range tests { + var actual []byte - ensureConnValid(t, conn) + err := conn.QueryRow(tt.sql, tt.queryArg).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql) + } + + if bytes.Compare(actual, tt.expected) != 0 { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql) + } + + ensureConnValid(t, conn) + } } func TestQueryRowUnknownType(t *testing.T) {