Handle Scanning null into core types
Return error instead of panic.
This commit is contained in:
@@ -408,6 +408,15 @@ func TestQueryRowCoreTypes(t *testing.T) {
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
|
||||
// Check that Scan errors when a core type is null
|
||||
err = conn.QueryRow(sql, nil).Scan(tt.scanArgs...)
|
||||
if err == nil {
|
||||
t.Errorf("%d. Expected null to cause error, but it didn't (sql -> %v)", i, sql)
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
|
||||
t.Errorf(`%d. Expected null to cause error "Cannot decode null..." but it was %v (sql -> %v)`, i, err, sql)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -527,6 +527,11 @@ func sanitizeArg(arg interface{}) (string, error) {
|
||||
}
|
||||
|
||||
func decodeBool(vr *ValueReader) bool {
|
||||
if vr.Len() == -1 {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into bool"))
|
||||
return false
|
||||
}
|
||||
|
||||
switch vr.Type().FormatCode {
|
||||
case TextFormatCode:
|
||||
s := vr.ReadString(vr.Len())
|
||||
@@ -571,6 +576,11 @@ func encodeBool(w *WriteBuf, value interface{}) error {
|
||||
}
|
||||
|
||||
func decodeInt8(vr *ValueReader) int64 {
|
||||
if vr.Len() == -1 {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into int64"))
|
||||
return 0
|
||||
}
|
||||
|
||||
if vr.Type().DataType != Int8Oid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int8Oid, vr.Type().DataType)))
|
||||
return 0
|
||||
@@ -632,6 +642,11 @@ func encodeInt8(w *WriteBuf, value interface{}) error {
|
||||
}
|
||||
|
||||
func decodeInt2(vr *ValueReader) int16 {
|
||||
if vr.Len() == -1 {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into int16"))
|
||||
return 0
|
||||
}
|
||||
|
||||
if vr.Type().DataType != Int2Oid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int2Oid, vr.Type().DataType)))
|
||||
return 0
|
||||
@@ -708,6 +723,11 @@ func encodeInt2(w *WriteBuf, value interface{}) error {
|
||||
}
|
||||
|
||||
func decodeInt4(vr *ValueReader) int32 {
|
||||
if vr.Len() == -1 {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into int32"))
|
||||
return 0
|
||||
}
|
||||
|
||||
if vr.Type().DataType != Int4Oid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int4Oid, vr.Type().DataType)))
|
||||
return 0
|
||||
@@ -777,6 +797,11 @@ func encodeInt4(w *WriteBuf, value interface{}) error {
|
||||
}
|
||||
|
||||
func decodeFloat4(vr *ValueReader) float32 {
|
||||
if vr.Len() == -1 {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into float32"))
|
||||
return 0
|
||||
}
|
||||
|
||||
switch vr.Type().FormatCode {
|
||||
case TextFormatCode:
|
||||
s := vr.ReadString(vr.Len())
|
||||
@@ -824,6 +849,11 @@ func encodeFloat4(w *WriteBuf, value interface{}) error {
|
||||
}
|
||||
|
||||
func decodeFloat8(vr *ValueReader) float64 {
|
||||
if vr.Len() == -1 {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into float64"))
|
||||
return 0
|
||||
}
|
||||
|
||||
switch vr.Type().FormatCode {
|
||||
case TextFormatCode:
|
||||
s := vr.ReadString(vr.Len())
|
||||
@@ -868,6 +898,11 @@ func encodeFloat8(w *WriteBuf, value interface{}) error {
|
||||
}
|
||||
|
||||
func decodeText(vr *ValueReader) string {
|
||||
if vr.Len() == -1 {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into string"))
|
||||
return ""
|
||||
}
|
||||
|
||||
return vr.ReadString(vr.Len())
|
||||
}
|
||||
|
||||
@@ -884,6 +919,10 @@ func encodeText(w *WriteBuf, value interface{}) error {
|
||||
}
|
||||
|
||||
func decodeBytea(vr *ValueReader) []byte {
|
||||
if vr.Len() == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch vr.Type().FormatCode {
|
||||
case TextFormatCode:
|
||||
s := vr.ReadString(vr.Len())
|
||||
@@ -916,6 +955,11 @@ func encodeBytea(w *WriteBuf, value interface{}) error {
|
||||
func decodeDate(vr *ValueReader) time.Time {
|
||||
var zeroTime time.Time
|
||||
|
||||
if vr.Len() == -1 {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into time.Time"))
|
||||
return zeroTime
|
||||
}
|
||||
|
||||
if vr.Type().DataType != DateOid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", DateOid, vr.Type().DataType)))
|
||||
return zeroTime
|
||||
@@ -957,6 +1001,11 @@ const microsecFromUnixEpochToY2K = 946684800 * 1000000
|
||||
func decodeTimestampTz(vr *ValueReader) time.Time {
|
||||
var zeroTime time.Time
|
||||
|
||||
if vr.Len() == -1 {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into time.Time"))
|
||||
return zeroTime
|
||||
}
|
||||
|
||||
if vr.Type().DataType != TimestampTzOid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampTzOid, vr.Type().DataType)))
|
||||
return zeroTime
|
||||
@@ -1002,6 +1051,11 @@ func encodeTimestampTz(w *WriteBuf, value interface{}) error {
|
||||
func decodeTimestamp(vr *ValueReader) time.Time {
|
||||
var zeroTime time.Time
|
||||
|
||||
if vr.Len() == -1 {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into timestamp"))
|
||||
return zeroTime
|
||||
}
|
||||
|
||||
if vr.Type().DataType != TimestampOid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampOid, vr.Type().DataType)))
|
||||
return zeroTime
|
||||
|
||||
Reference in New Issue
Block a user