Allow reading any value into []byte
This commit is contained in:
@@ -214,7 +214,13 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
|
|||||||
case *bool:
|
case *bool:
|
||||||
*d = decodeBool(vr)
|
*d = decodeBool(vr)
|
||||||
case *[]byte:
|
case *[]byte:
|
||||||
*d = decodeBytea(vr)
|
// If it actually is a bytea then pass it through decodeBytea (so it can be decoded if it is in text format)
|
||||||
|
// Otherwise read the bytes directly regardless of what the actual type is.
|
||||||
|
if vr.Type().DataType == ByteaOid {
|
||||||
|
*d = decodeBytea(vr)
|
||||||
|
} else {
|
||||||
|
*d = vr.ReadBytes(vr.Len())
|
||||||
|
}
|
||||||
case *int64:
|
case *int64:
|
||||||
*d = decodeInt8(vr)
|
*d = decodeInt8(vr)
|
||||||
case *int16:
|
case *int16:
|
||||||
|
|||||||
+22
-15
@@ -383,29 +383,36 @@ func TestQueryRowCoreTypes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryRowCoreBytea(t *testing.T) {
|
func TestQueryRowCoreByteSlice(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnect(t, *defaultConnConfig)
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
defer closeConn(t, conn)
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
var actual []byte
|
tests := []struct {
|
||||||
sql := "select $1::bytea"
|
sql string
|
||||||
queryArg := []byte{0, 15, 255, 17}
|
queryArg interface{}
|
||||||
expected := []byte{0, 15, 255, 17}
|
expected []byte
|
||||||
|
}{
|
||||||
actual = nil
|
{"select $1::text", "Jack", []byte("Jack")},
|
||||||
|
{"select $1::varchar", "Jack", []byte("Jack")},
|
||||||
err := conn.QueryRow(sql, queryArg).Scan(&actual)
|
{"select $1::bytea", []byte{0, 15, 255, 17}, []byte{0, 15, 255, 17}},
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if bytes.Compare(actual, expected) != 0 {
|
for i, tt := range tests {
|
||||||
t.Errorf("Expected %v, got %v (sql -> %v)", expected, actual, sql)
|
var actual []byte
|
||||||
}
|
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
err := conn.QueryRow(tt.sql, tt.queryArg).Scan(&actual)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Compare(actual, tt.expected) != 0 {
|
||||||
|
t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryRowUnknownType(t *testing.T) {
|
func TestQueryRowUnknownType(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user