diff --git a/bpchar.go b/bpchar.go index f82e3724..e4d058e9 100644 --- a/bpchar.go +++ b/bpchar.go @@ -33,6 +33,10 @@ func (src *BPChar) AssignTo(dst interface{}) error { return (*Text)(src).AssignTo(dst) } +func (BPChar) PreferredResultFormat() int16 { + return TextFormatCode +} + func (dst *BPChar) DecodeText(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeText(ci, src) } @@ -41,6 +45,10 @@ func (dst *BPChar) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeBinary(ci, src) } +func (BPChar) PreferredParamFormat() int16 { + return TextFormatCode +} + func (src BPChar) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return (Text)(src).EncodeText(ci, buf) } diff --git a/enum_type.go b/enum_type.go index 6f52817a..1a6a4b46 100644 --- a/enum_type.go +++ b/enum_type.go @@ -128,6 +128,10 @@ func (src *enumType) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (enumType) PreferredResultFormat() int16 { + return TextFormatCode +} + func (dst *enumType) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { dst.status = Null @@ -152,6 +156,10 @@ func (dst *enumType) DecodeBinary(ci *ConnInfo, src []byte) error { return dst.DecodeText(ci, src) } +func (enumType) PreferredParamFormat() int16 { + return TextFormatCode +} + func (src enumType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.status { case Null: diff --git a/json.go b/json.go index c642c727..922da50d 100644 --- a/json.go +++ b/json.go @@ -113,6 +113,10 @@ func (src *JSON) AssignTo(dst interface{}) error { return nil } +func (JSON) PreferredResultFormat() int16 { + return TextFormatCode +} + func (dst *JSON) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = JSON{Status: Null} @@ -127,6 +131,10 @@ func (dst *JSON) DecodeBinary(ci *ConnInfo, src []byte) error { return dst.DecodeText(ci, src) } +func (JSON) PreferredParamFormat() int16 { + return TextFormatCode +} + func (src JSON) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: diff --git a/jsonb.go b/jsonb.go index 984c0973..c129ac9b 100644 --- a/jsonb.go +++ b/jsonb.go @@ -20,6 +20,10 @@ func (src *JSONB) AssignTo(dst interface{}) error { return (*JSON)(src).AssignTo(dst) } +func (JSONB) PreferredResultFormat() int16 { + return TextFormatCode +} + func (dst *JSONB) DecodeText(ci *ConnInfo, src []byte) error { return (*JSON)(dst).DecodeText(ci, src) } @@ -43,6 +47,10 @@ func (dst *JSONB) DecodeBinary(ci *ConnInfo, src []byte) error { } +func (JSONB) PreferredParamFormat() int16 { + return TextFormatCode +} + func (src JSONB) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return (JSON)(src).EncodeText(ci, buf) } diff --git a/pgtype.go b/pgtype.go index 7c893360..ed676929 100644 --- a/pgtype.go +++ b/pgtype.go @@ -142,6 +142,18 @@ type TypeValue interface { TypeName() string } +// ResultFormatPreferrer allows a type to specify its preferred result format instead of it being inferred from +// whether it is also a BinaryDecoder. +type ResultFormatPreferrer interface { + PreferredResultFormat() int16 +} + +// ParamFormatPreferrer allows a type to specify its preferred param format instead of it being inferred from +// whether it is also a BinaryEncoder. +type ParamFormatPreferrer interface { + PreferredParamFormat() int16 +} + type BinaryDecoder interface { // DecodeBinary decodes src into BinaryDecoder. If src is nil then the // original SQL value is NULL. BinaryDecoder takes ownership of src. The @@ -364,7 +376,9 @@ func (ci *ConnInfo) RegisterDataType(t DataType) { { var formatCode int16 - if _, ok := t.Value.(BinaryEncoder); ok { + if pfp, ok := t.Value.(ParamFormatPreferrer); ok { + formatCode = pfp.PreferredParamFormat() + } else if _, ok := t.Value.(BinaryEncoder); ok { formatCode = BinaryFormatCode } ci.oidToParamFormatCode[t.OID] = formatCode @@ -372,7 +386,9 @@ func (ci *ConnInfo) RegisterDataType(t DataType) { { var formatCode int16 - if _, ok := t.Value.(BinaryDecoder); ok { + if rfp, ok := t.Value.(ResultFormatPreferrer); ok { + formatCode = rfp.PreferredResultFormat() + } else if _, ok := t.Value.(BinaryDecoder); ok { formatCode = BinaryFormatCode } ci.oidToResultFormatCode[t.OID] = formatCode diff --git a/pgtype_test.go b/pgtype_test.go index e1c49666..a96720d5 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -44,6 +44,26 @@ func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { return addr } +func TestConnInfoResultFormatCodeForOID(t *testing.T) { + ci := pgtype.NewConnInfo() + + // pgtype.JSONB implements BinaryDecoder but also implements ResultFormatPreferrer to override it to text. + assert.Equal(t, int16(pgtype.TextFormatCode), ci.ResultFormatCodeForOID(pgtype.JSONBOID)) + + // pgtype.Int4 implements BinaryDecoder but does not implement ResultFormatPreferrer so it should be binary. + assert.Equal(t, int16(pgtype.BinaryFormatCode), ci.ResultFormatCodeForOID(pgtype.Int4OID)) +} + +func TestConnInfoParamFormatCodeForOID(t *testing.T) { + ci := pgtype.NewConnInfo() + + // pgtype.JSONB implements BinaryEncoder but also implements ParamFormatPreferrer to override it to text. + assert.Equal(t, int16(pgtype.TextFormatCode), ci.ParamFormatCodeForOID(pgtype.JSONBOID)) + + // pgtype.Int4 implements BinaryEncoder but does not implement ParamFormatPreferrer so it should be binary. + assert.Equal(t, int16(pgtype.BinaryFormatCode), ci.ParamFormatCodeForOID(pgtype.Int4OID)) +} + func TestConnInfoScanUnknownOIDToStringsAndBytes(t *testing.T) { unknownOID := uint32(999999) srcBuf := []byte("foo") diff --git a/text.go b/text.go index 1f5d2a37..4c9e4a21 100644 --- a/text.go +++ b/text.go @@ -85,6 +85,10 @@ func (src *Text) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } +func (Text) PreferredResultFormat() int16 { + return TextFormatCode +} + func (dst *Text) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Text{Status: Null} @@ -99,6 +103,10 @@ func (dst *Text) DecodeBinary(ci *ConnInfo, src []byte) error { return dst.DecodeText(ci, src) } +func (Text) PreferredParamFormat() int16 { + return TextFormatCode +} + func (src Text) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: diff --git a/varchar.go b/varchar.go index e4fa6869..fea31d18 100644 --- a/varchar.go +++ b/varchar.go @@ -23,6 +23,10 @@ func (src *Varchar) AssignTo(dst interface{}) error { return (*Text)(src).AssignTo(dst) } +func (Varchar) PreferredResultFormat() int16 { + return TextFormatCode +} + func (dst *Varchar) DecodeText(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeText(ci, src) } @@ -31,6 +35,10 @@ func (dst *Varchar) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeBinary(ci, src) } +func (Varchar) PreferredParamFormat() int16 { + return TextFormatCode +} + func (src Varchar) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return (Text)(src).EncodeText(ci, buf) }