diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 3df322df..ce2336a2 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -1295,6 +1295,10 @@ func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan { } } + if _, ok := value.(driver.Valuer); ok { + return &encodePlanDriverValuer{m: m, oid: oid, formatCode: format} + } + for _, f := range m.TryWrapEncodePlanFuncs { if wrapperPlan, nextValue, ok := f(value); ok { if nextPlan := m.PlanEncode(oid, format, nextValue); nextPlan != nil { @@ -1328,6 +1332,75 @@ func (encodePlanTextValuerToAnyTextFormat) Encode(value any, buf []byte) (newBuf return append(buf, t.String...), nil } +type encodePlanDriverValuer struct { + m *Map + oid uint32 + formatCode int16 +} + +func (plan *encodePlanDriverValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + dv := value.(driver.Valuer) + if dv == nil { + return nil, nil + } + v, err := dv.Value() + if err != nil { + return nil, err + } + if v == nil { + return nil, nil + } + + newBuf, err = plan.m.Encode(plan.oid, plan.formatCode, v, buf) + if err == nil { + return newBuf, nil + } + + s, ok := v.(string) + if !ok { + return nil, err + } + + var scannedValue any + scanErr := plan.m.Scan(plan.oid, TextFormatCode, []byte(s), &scannedValue) + if scanErr != nil { + return nil, err + } + + var err2 error + newBuf, err2 = plan.m.Encode(plan.oid, BinaryFormatCode, scannedValue, buf) + if err2 != nil { + return nil, err + } + + return newBuf, nil +} + +type encodePlanStringToBinary struct { + m *Map + oid uint32 +} + +func (plan *encodePlanStringToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + s, ok := value.(string) + if !ok { + return nil, fmt.Errorf("expected %v to be a string to attempt conversion to binary", value) + } + + var scannedValue any + err = plan.m.Scan(plan.oid, TextFormatCode, []byte(s), &scannedValue) + if err != nil { + return nil, fmt.Errorf("tried to scan %v to convert to binary but failed: %v", value, err) + } + + newBuf, err = plan.m.Encode(plan.oid, BinaryFormatCode, scannedValue, buf) + if err != nil { + return nil, fmt.Errorf("tried to encode %v to binary via scanning but failed: %v", value, err) + } + + return newBuf, nil +} + // TryWrapEncodePlanFunc is a function that tries to create a wrapper plan for value. If successful it returns a plan // that will convert the value passed to Encode and then call the next plan. nextValue is value as it will be converted // by plan. It must be used to find another suitable EncodePlan. When it is found SetNext must be called on plan for it @@ -1873,17 +1946,6 @@ func (m *Map) Encode(oid uint32, formatCode int16, value any, buf []byte) (newBu plan := m.PlanEncode(oid, formatCode, value) if plan == nil { - if dv, ok := value.(driver.Valuer); ok { - if dv == nil { - return nil, nil - } - v, err := dv.Value() - if err != nil { - return nil, err - } - return m.Encode(oid, formatCode, v, buf) - } - return nil, newEncodeError(value, m, oid, formatCode, errors.New("cannot find encode plan")) } diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 11ae39e0..8363b3b8 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -3,7 +3,9 @@ package pgtype_test import ( "context" "database/sql" + "database/sql/driver" "errors" + "fmt" "net" "os" "regexp" @@ -265,6 +267,56 @@ func TestMapScanPtrToPtrToSlice(t *testing.T) { require.Equal(t, []string{"foo", "bar"}, *v) } +type databaseValuerString string + +func (s databaseValuerString) Value() (driver.Value, error) { + return fmt.Sprintf("%d", len(s)), nil +} + +// https://github.com/jackc/pgx/issues/1319 +func TestMapEncodeTextFormatDatabaseValuerThatIsRenamedSimpleType(t *testing.T) { + m := pgtype.NewMap() + src := databaseValuerString("foo") + buf, err := m.Encode(pgtype.TextOID, pgtype.TextFormatCode, src, nil) + require.NoError(t, err) + require.Equal(t, "3", string(buf)) +} + +type databaseValuerFmtStringer string + +func (s databaseValuerFmtStringer) Value() (driver.Value, error) { + return nil, nil +} + +func (s databaseValuerFmtStringer) String() string { + return "foobar" +} + +// https://github.com/jackc/pgx/issues/1311 +func TestMapEncodeTextFormatDatabaseValuerThatIsFmtStringer(t *testing.T) { + m := pgtype.NewMap() + src := databaseValuerFmtStringer("") + buf, err := m.Encode(pgtype.TextOID, pgtype.TextFormatCode, src, nil) + require.NoError(t, err) + require.Nil(t, buf) +} + +type databaseValuerStringFormat struct { + n int32 +} + +func (v databaseValuerStringFormat) Value() (driver.Value, error) { + return fmt.Sprint(v.n), nil +} + +func TestMapEncodeBinaryFormatDatabaseValuerThatReturnsString(t *testing.T) { + m := pgtype.NewMap() + src := databaseValuerStringFormat{n: 42} + buf, err := m.Encode(pgtype.Int4OID, pgtype.BinaryFormatCode, src, nil) + require.NoError(t, err) + require.Equal(t, []byte{0, 0, 0, 42}, buf) +} + func BenchmarkMapScanInt4IntoBinaryDecoder(b *testing.B) { m := pgtype.NewMap() src := []byte{0, 0, 0, 42}