From 222e3b37bc0c3b252e4d1aea7523b67df4fa38ee Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 1 Oct 2022 12:20:23 -0500 Subject: [PATCH] Prefer driver.Value over wrap plans when encoding This is tricky due to driver.Valuer returning any. For example, we can plan for fmt.Stringer because it always returns a string. Because of this driver.Valuer was always handled as the last option. But with pgx v5 now having the ability to find underlying types like a string and supporting fmt.Stringer it meant that driver.Valuer was often not getting called because something else was found first. This change tries driver.Valuer immediately after the initial PlanScan for the Codec. So a type that directly implements a pgx interface should be used, but driver.Valuer will be prefered before all the attempts to handle renamed types, pointer deferencing, etc. fixes https://github.com/jackc/pgx/issues/1319 fixes https://github.com/jackc/pgx/issues/1311 --- pgtype/pgtype.go | 84 +++++++++++++++++++++++++++++++++++++------ pgtype/pgtype_test.go | 52 +++++++++++++++++++++++++++ 2 files changed, 125 insertions(+), 11 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 3df322df..ce2336a2 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -1295,6 +1295,10 @@ func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan { } } + if _, ok := value.(driver.Valuer); ok { + return &encodePlanDriverValuer{m: m, oid: oid, formatCode: format} + } + for _, f := range m.TryWrapEncodePlanFuncs { if wrapperPlan, nextValue, ok := f(value); ok { if nextPlan := m.PlanEncode(oid, format, nextValue); nextPlan != nil { @@ -1328,6 +1332,75 @@ func (encodePlanTextValuerToAnyTextFormat) Encode(value any, buf []byte) (newBuf return append(buf, t.String...), nil } +type encodePlanDriverValuer struct { + m *Map + oid uint32 + formatCode int16 +} + +func (plan *encodePlanDriverValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + dv := value.(driver.Valuer) + if dv == nil { + return nil, nil + } + v, err := dv.Value() + if err != nil { + return nil, err + } + if v == nil { + return nil, nil + } + + newBuf, err = plan.m.Encode(plan.oid, plan.formatCode, v, buf) + if err == nil { + return newBuf, nil + } + + s, ok := v.(string) + if !ok { + return nil, err + } + + var scannedValue any + scanErr := plan.m.Scan(plan.oid, TextFormatCode, []byte(s), &scannedValue) + if scanErr != nil { + return nil, err + } + + var err2 error + newBuf, err2 = plan.m.Encode(plan.oid, BinaryFormatCode, scannedValue, buf) + if err2 != nil { + return nil, err + } + + return newBuf, nil +} + +type encodePlanStringToBinary struct { + m *Map + oid uint32 +} + +func (plan *encodePlanStringToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + s, ok := value.(string) + if !ok { + return nil, fmt.Errorf("expected %v to be a string to attempt conversion to binary", value) + } + + var scannedValue any + err = plan.m.Scan(plan.oid, TextFormatCode, []byte(s), &scannedValue) + if err != nil { + return nil, fmt.Errorf("tried to scan %v to convert to binary but failed: %v", value, err) + } + + newBuf, err = plan.m.Encode(plan.oid, BinaryFormatCode, scannedValue, buf) + if err != nil { + return nil, fmt.Errorf("tried to encode %v to binary via scanning but failed: %v", value, err) + } + + return newBuf, nil +} + // TryWrapEncodePlanFunc is a function that tries to create a wrapper plan for value. If successful it returns a plan // that will convert the value passed to Encode and then call the next plan. nextValue is value as it will be converted // by plan. It must be used to find another suitable EncodePlan. When it is found SetNext must be called on plan for it @@ -1873,17 +1946,6 @@ func (m *Map) Encode(oid uint32, formatCode int16, value any, buf []byte) (newBu plan := m.PlanEncode(oid, formatCode, value) if plan == nil { - if dv, ok := value.(driver.Valuer); ok { - if dv == nil { - return nil, nil - } - v, err := dv.Value() - if err != nil { - return nil, err - } - return m.Encode(oid, formatCode, v, buf) - } - return nil, newEncodeError(value, m, oid, formatCode, errors.New("cannot find encode plan")) } diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 11ae39e0..8363b3b8 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -3,7 +3,9 @@ package pgtype_test import ( "context" "database/sql" + "database/sql/driver" "errors" + "fmt" "net" "os" "regexp" @@ -265,6 +267,56 @@ func TestMapScanPtrToPtrToSlice(t *testing.T) { require.Equal(t, []string{"foo", "bar"}, *v) } +type databaseValuerString string + +func (s databaseValuerString) Value() (driver.Value, error) { + return fmt.Sprintf("%d", len(s)), nil +} + +// https://github.com/jackc/pgx/issues/1319 +func TestMapEncodeTextFormatDatabaseValuerThatIsRenamedSimpleType(t *testing.T) { + m := pgtype.NewMap() + src := databaseValuerString("foo") + buf, err := m.Encode(pgtype.TextOID, pgtype.TextFormatCode, src, nil) + require.NoError(t, err) + require.Equal(t, "3", string(buf)) +} + +type databaseValuerFmtStringer string + +func (s databaseValuerFmtStringer) Value() (driver.Value, error) { + return nil, nil +} + +func (s databaseValuerFmtStringer) String() string { + return "foobar" +} + +// https://github.com/jackc/pgx/issues/1311 +func TestMapEncodeTextFormatDatabaseValuerThatIsFmtStringer(t *testing.T) { + m := pgtype.NewMap() + src := databaseValuerFmtStringer("") + buf, err := m.Encode(pgtype.TextOID, pgtype.TextFormatCode, src, nil) + require.NoError(t, err) + require.Nil(t, buf) +} + +type databaseValuerStringFormat struct { + n int32 +} + +func (v databaseValuerStringFormat) Value() (driver.Value, error) { + return fmt.Sprint(v.n), nil +} + +func TestMapEncodeBinaryFormatDatabaseValuerThatReturnsString(t *testing.T) { + m := pgtype.NewMap() + src := databaseValuerStringFormat{n: 42} + buf, err := m.Encode(pgtype.Int4OID, pgtype.BinaryFormatCode, src, nil) + require.NoError(t, err) + require.Equal(t, []byte{0, 0, 0, 42}, buf) +} + func BenchmarkMapScanInt4IntoBinaryDecoder(b *testing.B) { m := pgtype.NewMap() src := []byte{0, 0, 0, 42}