diff --git a/pgtype.go b/pgtype.go index f6c354ef..bb0a99af 100644 --- a/pgtype.go +++ b/pgtype.go @@ -164,8 +164,12 @@ var errBadStatus = errors.New("invalid status") type DataType struct { Value Value - Name string - OID uint32 + + textDecoder TextDecoder + binaryDecoder BinaryDecoder + + Name string + OID uint32 } type ConnInfo struct { @@ -285,6 +289,14 @@ func (ci *ConnInfo) RegisterDataType(t DataType) { } ci.oidToResultFormatCode[t.OID] = formatCode } + + if d, ok := t.Value.(TextDecoder); ok { + t.textDecoder = d + } + + if d, ok := t.Value.(BinaryDecoder); ok { + t.binaryDecoder = d + } } func (ci *ConnInfo) DataTypeForOID(oid uint32) (*DataType, bool) { @@ -374,25 +386,24 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interfac } if dt, ok := ci.DataTypeForOID(oid); ok { - value := dt.Value switch formatCode { - case TextFormatCode: - if textDecoder, ok := value.(TextDecoder); ok { - err := textDecoder.DecodeText(ci, buf) - if err != nil { - return err - } - } else { - return errors.Errorf("%T is not a pgtype.TextDecoder", value) - } case BinaryFormatCode: - if binaryDecoder, ok := value.(BinaryDecoder); ok { - err := binaryDecoder.DecodeBinary(ci, buf) + if dt.binaryDecoder != nil { + err := dt.binaryDecoder.DecodeBinary(ci, buf) if err != nil { return err } } else { - return errors.Errorf("%T is not a pgtype.BinaryDecoder", value) + return errors.Errorf("%T is not a pgtype.BinaryDecoder", dt.Value) + } + case TextFormatCode: + if dt.textDecoder != nil { + err := dt.textDecoder.DecodeText(ci, buf) + if err != nil { + return err + } + } else { + return errors.Errorf("%T is not a pgtype.TextDecoder", dt.Value) } default: return errors.Errorf("unknown format code: %v", formatCode) @@ -400,7 +411,7 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interfac if !isFastType { if scanner, ok := dest.(sql.Scanner); ok { - sqlSrc, err := DatabaseSQLValue(ci, value) + sqlSrc, err := DatabaseSQLValue(ci, dt.Value) if err != nil { return err } @@ -408,7 +419,7 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interfac } } - return value.AssignTo(dest) + return dt.Value.AssignTo(dest) } // We might be given a pointer to something that implements the decoder interface(s),