diff --git a/query_test.go b/query_test.go index 8af481f2..2f825a09 100644 --- a/query_test.go +++ b/query_test.go @@ -408,6 +408,15 @@ func TestQueryRowCoreTypes(t *testing.T) { } ensureConnValid(t, conn) + + // Check that Scan errors when a core type is null + err = conn.QueryRow(sql, nil).Scan(tt.scanArgs...) + if err == nil { + t.Errorf("%d. Expected null to cause error, but it didn't (sql -> %v)", i, sql) + } + if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { + t.Errorf(`%d. Expected null to cause error "Cannot decode null..." but it was %v (sql -> %v)`, i, err, sql) + } } } } diff --git a/values.go b/values.go index 325c198a..38f78119 100644 --- a/values.go +++ b/values.go @@ -527,6 +527,11 @@ func sanitizeArg(arg interface{}) (string, error) { } func decodeBool(vr *ValueReader) bool { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into bool")) + return false + } + switch vr.Type().FormatCode { case TextFormatCode: s := vr.ReadString(vr.Len()) @@ -571,6 +576,11 @@ func encodeBool(w *WriteBuf, value interface{}) error { } func decodeInt8(vr *ValueReader) int64 { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into int64")) + return 0 + } + if vr.Type().DataType != Int8Oid { vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int8Oid, vr.Type().DataType))) return 0 @@ -632,6 +642,11 @@ func encodeInt8(w *WriteBuf, value interface{}) error { } func decodeInt2(vr *ValueReader) int16 { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into int16")) + return 0 + } + if vr.Type().DataType != Int2Oid { vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int2Oid, vr.Type().DataType))) return 0 @@ -708,6 +723,11 @@ func encodeInt2(w *WriteBuf, value interface{}) error { } func decodeInt4(vr *ValueReader) int32 { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into int32")) + return 0 + } + if vr.Type().DataType != Int4Oid { vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int4Oid, vr.Type().DataType))) return 0 @@ -777,6 +797,11 @@ func encodeInt4(w *WriteBuf, value interface{}) error { } func decodeFloat4(vr *ValueReader) float32 { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into float32")) + return 0 + } + switch vr.Type().FormatCode { case TextFormatCode: s := vr.ReadString(vr.Len()) @@ -824,6 +849,11 @@ func encodeFloat4(w *WriteBuf, value interface{}) error { } func decodeFloat8(vr *ValueReader) float64 { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into float64")) + return 0 + } + switch vr.Type().FormatCode { case TextFormatCode: s := vr.ReadString(vr.Len()) @@ -868,6 +898,11 @@ func encodeFloat8(w *WriteBuf, value interface{}) error { } func decodeText(vr *ValueReader) string { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into string")) + return "" + } + return vr.ReadString(vr.Len()) } @@ -884,6 +919,10 @@ func encodeText(w *WriteBuf, value interface{}) error { } func decodeBytea(vr *ValueReader) []byte { + if vr.Len() == -1 { + return nil + } + switch vr.Type().FormatCode { case TextFormatCode: s := vr.ReadString(vr.Len()) @@ -916,6 +955,11 @@ func encodeBytea(w *WriteBuf, value interface{}) error { func decodeDate(vr *ValueReader) time.Time { var zeroTime time.Time + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into time.Time")) + return zeroTime + } + if vr.Type().DataType != DateOid { vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", DateOid, vr.Type().DataType))) return zeroTime @@ -957,6 +1001,11 @@ const microsecFromUnixEpochToY2K = 946684800 * 1000000 func decodeTimestampTz(vr *ValueReader) time.Time { var zeroTime time.Time + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into time.Time")) + return zeroTime + } + if vr.Type().DataType != TimestampTzOid { vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampTzOid, vr.Type().DataType))) return zeroTime @@ -1002,6 +1051,11 @@ func encodeTimestampTz(w *WriteBuf, value interface{}) error { func decodeTimestamp(vr *ValueReader) time.Time { var zeroTime time.Time + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into timestamp")) + return zeroTime + } + if vr.Type().DataType != TimestampOid { vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampOid, vr.Type().DataType))) return zeroTime