diff --git a/pgtype.go b/pgtype.go index 1705ae41..d8dd5abf 100644 --- a/pgtype.go +++ b/pgtype.go @@ -179,12 +179,6 @@ type ResultFormatPreferrer interface { PreferredResultFormat() int16 } -// ParamFormatPreferrer allows a type to specify its preferred param format instead of it being inferred from -// whether it is also a BinaryEncoder. -type ParamFormatPreferrer interface { - PreferredParamFormat() int16 -} - type BinaryDecoder interface { // DecodeBinary decodes src into BinaryDecoder. If src is nil then the // original SQL value is NULL. BinaryDecoder takes ownership of src. The @@ -243,7 +237,7 @@ type ConnInfo struct { oidToDataType map[uint32]*DataType nameToDataType map[string]*DataType reflectTypeToName map[reflect.Type]string - oidToParamFormatCode map[uint32]int16 + oidToFormatCode map[uint32]int16 oidToResultFormatCode map[uint32]int16 reflectTypeToDataType map[reflect.Type]*DataType @@ -256,7 +250,7 @@ func newConnInfo() *ConnInfo { oidToDataType: make(map[uint32]*DataType), nameToDataType: make(map[string]*DataType), reflectTypeToName: make(map[reflect.Type]string), - oidToParamFormatCode: make(map[uint32]int16), + oidToFormatCode: make(map[uint32]int16), oidToResultFormatCode: make(map[uint32]int16), preferAssignToOverSQLScannerTypes: make(map[reflect.Type]struct{}), } @@ -392,24 +386,12 @@ func (ci *ConnInfo) RegisterDataType(t DataType) { { var formatCode int16 - if pfp, ok := t.Value.(ParamFormatPreferrer); ok { - formatCode = pfp.PreferredParamFormat() + if pfp, ok := t.Value.(FormatSupport); ok { + formatCode = pfp.PreferredFormat() } else if _, ok := t.Value.(BinaryEncoder); ok { formatCode = BinaryFormatCode } - ci.oidToParamFormatCode[t.OID] = formatCode - } - - { - var formatCode int16 - if fs, ok := t.Value.(FormatSupport); ok { - formatCode = fs.PreferredFormat() - } else if rfp, ok := t.Value.(ResultFormatPreferrer); ok { - formatCode = rfp.PreferredResultFormat() - } else if _, ok := t.Value.(BinaryDecoder); ok { - formatCode = BinaryFormatCode - } - ci.oidToResultFormatCode[t.OID] = formatCode + ci.oidToFormatCode[t.OID] = formatCode } if d, ok := t.Value.(ResultDecoder); ok { @@ -477,16 +459,8 @@ func (ci *ConnInfo) DataTypeForValue(v interface{}) (*DataType, bool) { return dt, ok } -func (ci *ConnInfo) ParamFormatCodeForOID(oid uint32) int16 { - fc, ok := ci.oidToParamFormatCode[oid] - if ok { - return fc - } - return TextFormatCode -} - -func (ci *ConnInfo) ResultFormatCodeForOID(oid uint32) int16 { - fc, ok := ci.oidToResultFormatCode[oid] +func (ci *ConnInfo) FormatCodeForOID(oid uint32) int16 { + fc, ok := ci.oidToFormatCode[oid] if ok { return fc } diff --git a/pgtype_test.go b/pgtype_test.go index 7ae756e5..9bf1f242 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -67,24 +67,14 @@ func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { return addr } -func TestConnInfoResultFormatCodeForOID(t *testing.T) { - ci := pgtype.NewConnInfo() - - // pgtype.JSONB implements BinaryDecoder but also implements ResultFormatPreferrer to override it to text. - assert.Equal(t, int16(pgtype.TextFormatCode), ci.ResultFormatCodeForOID(pgtype.JSONBOID)) - - // pgtype.Int4 implements BinaryDecoder but does not implement ResultFormatPreferrer so it should be binary. - assert.Equal(t, int16(pgtype.BinaryFormatCode), ci.ResultFormatCodeForOID(pgtype.Int4OID)) -} - -func TestConnInfoParamFormatCodeForOID(t *testing.T) { +func TestConnInfoFormatCodeForOID(t *testing.T) { ci := pgtype.NewConnInfo() // pgtype.JSONB implements BinaryEncoder but also implements ParamFormatPreferrer to override it to text. - assert.Equal(t, int16(pgtype.TextFormatCode), ci.ParamFormatCodeForOID(pgtype.JSONBOID)) + assert.Equal(t, int16(pgtype.TextFormatCode), ci.FormatCodeForOID(pgtype.JSONBOID)) // pgtype.Int4 implements BinaryEncoder but does not implement ParamFormatPreferrer so it should be binary. - assert.Equal(t, int16(pgtype.BinaryFormatCode), ci.ParamFormatCodeForOID(pgtype.Int4OID)) + assert.Equal(t, int16(pgtype.BinaryFormatCode), ci.FormatCodeForOID(pgtype.Int4OID)) } func TestConnInfoScanNilIsNoOp(t *testing.T) {