Handle Scanning null into core types
Return error instead of panic.
This commit is contained in:
@@ -408,6 +408,15 @@ func TestQueryRowCoreTypes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
|
|
||||||
|
// Check that Scan errors when a core type is null
|
||||||
|
err = conn.QueryRow(sql, nil).Scan(tt.scanArgs...)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("%d. Expected null to cause error, but it didn't (sql -> %v)", i, sql)
|
||||||
|
}
|
||||||
|
if err != nil && !strings.Contains(err.Error(), "Cannot decode null") {
|
||||||
|
t.Errorf(`%d. Expected null to cause error "Cannot decode null..." but it was %v (sql -> %v)`, i, err, sql)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -527,6 +527,11 @@ func sanitizeArg(arg interface{}) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func decodeBool(vr *ValueReader) bool {
|
func decodeBool(vr *ValueReader) bool {
|
||||||
|
if vr.Len() == -1 {
|
||||||
|
vr.Fatal(ProtocolError("Cannot decode null into bool"))
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
switch vr.Type().FormatCode {
|
switch vr.Type().FormatCode {
|
||||||
case TextFormatCode:
|
case TextFormatCode:
|
||||||
s := vr.ReadString(vr.Len())
|
s := vr.ReadString(vr.Len())
|
||||||
@@ -571,6 +576,11 @@ func encodeBool(w *WriteBuf, value interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func decodeInt8(vr *ValueReader) int64 {
|
func decodeInt8(vr *ValueReader) int64 {
|
||||||
|
if vr.Len() == -1 {
|
||||||
|
vr.Fatal(ProtocolError("Cannot decode null into int64"))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
if vr.Type().DataType != Int8Oid {
|
if vr.Type().DataType != Int8Oid {
|
||||||
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int8Oid, vr.Type().DataType)))
|
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int8Oid, vr.Type().DataType)))
|
||||||
return 0
|
return 0
|
||||||
@@ -632,6 +642,11 @@ func encodeInt8(w *WriteBuf, value interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func decodeInt2(vr *ValueReader) int16 {
|
func decodeInt2(vr *ValueReader) int16 {
|
||||||
|
if vr.Len() == -1 {
|
||||||
|
vr.Fatal(ProtocolError("Cannot decode null into int16"))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
if vr.Type().DataType != Int2Oid {
|
if vr.Type().DataType != Int2Oid {
|
||||||
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int2Oid, vr.Type().DataType)))
|
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int2Oid, vr.Type().DataType)))
|
||||||
return 0
|
return 0
|
||||||
@@ -708,6 +723,11 @@ func encodeInt2(w *WriteBuf, value interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func decodeInt4(vr *ValueReader) int32 {
|
func decodeInt4(vr *ValueReader) int32 {
|
||||||
|
if vr.Len() == -1 {
|
||||||
|
vr.Fatal(ProtocolError("Cannot decode null into int32"))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
if vr.Type().DataType != Int4Oid {
|
if vr.Type().DataType != Int4Oid {
|
||||||
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int4Oid, vr.Type().DataType)))
|
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int4Oid, vr.Type().DataType)))
|
||||||
return 0
|
return 0
|
||||||
@@ -777,6 +797,11 @@ func encodeInt4(w *WriteBuf, value interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func decodeFloat4(vr *ValueReader) float32 {
|
func decodeFloat4(vr *ValueReader) float32 {
|
||||||
|
if vr.Len() == -1 {
|
||||||
|
vr.Fatal(ProtocolError("Cannot decode null into float32"))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
switch vr.Type().FormatCode {
|
switch vr.Type().FormatCode {
|
||||||
case TextFormatCode:
|
case TextFormatCode:
|
||||||
s := vr.ReadString(vr.Len())
|
s := vr.ReadString(vr.Len())
|
||||||
@@ -824,6 +849,11 @@ func encodeFloat4(w *WriteBuf, value interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func decodeFloat8(vr *ValueReader) float64 {
|
func decodeFloat8(vr *ValueReader) float64 {
|
||||||
|
if vr.Len() == -1 {
|
||||||
|
vr.Fatal(ProtocolError("Cannot decode null into float64"))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
switch vr.Type().FormatCode {
|
switch vr.Type().FormatCode {
|
||||||
case TextFormatCode:
|
case TextFormatCode:
|
||||||
s := vr.ReadString(vr.Len())
|
s := vr.ReadString(vr.Len())
|
||||||
@@ -868,6 +898,11 @@ func encodeFloat8(w *WriteBuf, value interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func decodeText(vr *ValueReader) string {
|
func decodeText(vr *ValueReader) string {
|
||||||
|
if vr.Len() == -1 {
|
||||||
|
vr.Fatal(ProtocolError("Cannot decode null into string"))
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
return vr.ReadString(vr.Len())
|
return vr.ReadString(vr.Len())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -884,6 +919,10 @@ func encodeText(w *WriteBuf, value interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func decodeBytea(vr *ValueReader) []byte {
|
func decodeBytea(vr *ValueReader) []byte {
|
||||||
|
if vr.Len() == -1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
switch vr.Type().FormatCode {
|
switch vr.Type().FormatCode {
|
||||||
case TextFormatCode:
|
case TextFormatCode:
|
||||||
s := vr.ReadString(vr.Len())
|
s := vr.ReadString(vr.Len())
|
||||||
@@ -916,6 +955,11 @@ func encodeBytea(w *WriteBuf, value interface{}) error {
|
|||||||
func decodeDate(vr *ValueReader) time.Time {
|
func decodeDate(vr *ValueReader) time.Time {
|
||||||
var zeroTime time.Time
|
var zeroTime time.Time
|
||||||
|
|
||||||
|
if vr.Len() == -1 {
|
||||||
|
vr.Fatal(ProtocolError("Cannot decode null into time.Time"))
|
||||||
|
return zeroTime
|
||||||
|
}
|
||||||
|
|
||||||
if vr.Type().DataType != DateOid {
|
if vr.Type().DataType != DateOid {
|
||||||
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", DateOid, vr.Type().DataType)))
|
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", DateOid, vr.Type().DataType)))
|
||||||
return zeroTime
|
return zeroTime
|
||||||
@@ -957,6 +1001,11 @@ const microsecFromUnixEpochToY2K = 946684800 * 1000000
|
|||||||
func decodeTimestampTz(vr *ValueReader) time.Time {
|
func decodeTimestampTz(vr *ValueReader) time.Time {
|
||||||
var zeroTime time.Time
|
var zeroTime time.Time
|
||||||
|
|
||||||
|
if vr.Len() == -1 {
|
||||||
|
vr.Fatal(ProtocolError("Cannot decode null into time.Time"))
|
||||||
|
return zeroTime
|
||||||
|
}
|
||||||
|
|
||||||
if vr.Type().DataType != TimestampTzOid {
|
if vr.Type().DataType != TimestampTzOid {
|
||||||
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampTzOid, vr.Type().DataType)))
|
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampTzOid, vr.Type().DataType)))
|
||||||
return zeroTime
|
return zeroTime
|
||||||
@@ -1002,6 +1051,11 @@ func encodeTimestampTz(w *WriteBuf, value interface{}) error {
|
|||||||
func decodeTimestamp(vr *ValueReader) time.Time {
|
func decodeTimestamp(vr *ValueReader) time.Time {
|
||||||
var zeroTime time.Time
|
var zeroTime time.Time
|
||||||
|
|
||||||
|
if vr.Len() == -1 {
|
||||||
|
vr.Fatal(ProtocolError("Cannot decode null into timestamp"))
|
||||||
|
return zeroTime
|
||||||
|
}
|
||||||
|
|
||||||
if vr.Type().DataType != TimestampOid {
|
if vr.Type().DataType != TimestampOid {
|
||||||
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampOid, vr.Type().DataType)))
|
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampOid, vr.Type().DataType)))
|
||||||
return zeroTime
|
return zeroTime
|
||||||
|
|||||||
Reference in New Issue
Block a user