Avoid type assertion in Scan
Before: BenchmarkConnInfoScanInt4IntoBinaryDecoder-16 79859755 14.6 ns/op 0 B/op 0 allocs/op BenchmarkConnInfoScanInt4IntoGoInt32-16 38969991 30.0 ns/op 0 B/op 0 allocs/op After: BenchmarkConnInfoScanInt4IntoBinaryDecoder-16 458046958 13.3 ns/op 0 B/op 0 allocs/op BenchmarkConnInfoScanInt4IntoGoInt32-16 275791776 20.6 ns/op 0 B/op 0 allocs/op
This commit is contained in:
@@ -164,8 +164,12 @@ var errBadStatus = errors.New("invalid status")
|
||||
|
||||
type DataType struct {
|
||||
Value Value
|
||||
Name string
|
||||
OID uint32
|
||||
|
||||
textDecoder TextDecoder
|
||||
binaryDecoder BinaryDecoder
|
||||
|
||||
Name string
|
||||
OID uint32
|
||||
}
|
||||
|
||||
type ConnInfo struct {
|
||||
@@ -285,6 +289,14 @@ func (ci *ConnInfo) RegisterDataType(t DataType) {
|
||||
}
|
||||
ci.oidToResultFormatCode[t.OID] = formatCode
|
||||
}
|
||||
|
||||
if d, ok := t.Value.(TextDecoder); ok {
|
||||
t.textDecoder = d
|
||||
}
|
||||
|
||||
if d, ok := t.Value.(BinaryDecoder); ok {
|
||||
t.binaryDecoder = d
|
||||
}
|
||||
}
|
||||
|
||||
func (ci *ConnInfo) DataTypeForOID(oid uint32) (*DataType, bool) {
|
||||
@@ -374,25 +386,24 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interfac
|
||||
}
|
||||
|
||||
if dt, ok := ci.DataTypeForOID(oid); ok {
|
||||
value := dt.Value
|
||||
switch formatCode {
|
||||
case TextFormatCode:
|
||||
if textDecoder, ok := value.(TextDecoder); ok {
|
||||
err := textDecoder.DecodeText(ci, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return errors.Errorf("%T is not a pgtype.TextDecoder", value)
|
||||
}
|
||||
case BinaryFormatCode:
|
||||
if binaryDecoder, ok := value.(BinaryDecoder); ok {
|
||||
err := binaryDecoder.DecodeBinary(ci, buf)
|
||||
if dt.binaryDecoder != nil {
|
||||
err := dt.binaryDecoder.DecodeBinary(ci, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return errors.Errorf("%T is not a pgtype.BinaryDecoder", value)
|
||||
return errors.Errorf("%T is not a pgtype.BinaryDecoder", dt.Value)
|
||||
}
|
||||
case TextFormatCode:
|
||||
if dt.textDecoder != nil {
|
||||
err := dt.textDecoder.DecodeText(ci, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return errors.Errorf("%T is not a pgtype.TextDecoder", dt.Value)
|
||||
}
|
||||
default:
|
||||
return errors.Errorf("unknown format code: %v", formatCode)
|
||||
@@ -400,7 +411,7 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interfac
|
||||
|
||||
if !isFastType {
|
||||
if scanner, ok := dest.(sql.Scanner); ok {
|
||||
sqlSrc, err := DatabaseSQLValue(ci, value)
|
||||
sqlSrc, err := DatabaseSQLValue(ci, dt.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -408,7 +419,7 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interfac
|
||||
}
|
||||
}
|
||||
|
||||
return value.AssignTo(dest)
|
||||
return dt.Value.AssignTo(dest)
|
||||
}
|
||||
|
||||
// We might be given a pointer to something that implements the decoder interface(s),
|
||||
|
||||
Reference in New Issue
Block a user