Prefer driver.Value over wrap plans when encoding
This is tricky due to driver.Valuer returning any. For example, we can plan for fmt.Stringer because it always returns a string. Because of this driver.Valuer was always handled as the last option. But with pgx v5 now having the ability to find underlying types like a string and supporting fmt.Stringer it meant that driver.Valuer was often not getting called because something else was found first. This change tries driver.Valuer immediately after the initial PlanScan for the Codec. So a type that directly implements a pgx interface should be used, but driver.Valuer will be prefered before all the attempts to handle renamed types, pointer deferencing, etc. fixes https://github.com/jackc/pgx/issues/1319 fixes https://github.com/jackc/pgx/issues/1311
This commit is contained in:
+73
-11
@@ -1295,6 +1295,10 @@ func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan {
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := value.(driver.Valuer); ok {
|
||||
return &encodePlanDriverValuer{m: m, oid: oid, formatCode: format}
|
||||
}
|
||||
|
||||
for _, f := range m.TryWrapEncodePlanFuncs {
|
||||
if wrapperPlan, nextValue, ok := f(value); ok {
|
||||
if nextPlan := m.PlanEncode(oid, format, nextValue); nextPlan != nil {
|
||||
@@ -1328,6 +1332,75 @@ func (encodePlanTextValuerToAnyTextFormat) Encode(value any, buf []byte) (newBuf
|
||||
return append(buf, t.String...), nil
|
||||
}
|
||||
|
||||
type encodePlanDriverValuer struct {
|
||||
m *Map
|
||||
oid uint32
|
||||
formatCode int16
|
||||
}
|
||||
|
||||
func (plan *encodePlanDriverValuer) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||
dv := value.(driver.Valuer)
|
||||
if dv == nil {
|
||||
return nil, nil
|
||||
}
|
||||
v, err := dv.Value()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
newBuf, err = plan.m.Encode(plan.oid, plan.formatCode, v, buf)
|
||||
if err == nil {
|
||||
return newBuf, nil
|
||||
}
|
||||
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var scannedValue any
|
||||
scanErr := plan.m.Scan(plan.oid, TextFormatCode, []byte(s), &scannedValue)
|
||||
if scanErr != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var err2 error
|
||||
newBuf, err2 = plan.m.Encode(plan.oid, BinaryFormatCode, scannedValue, buf)
|
||||
if err2 != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newBuf, nil
|
||||
}
|
||||
|
||||
type encodePlanStringToBinary struct {
|
||||
m *Map
|
||||
oid uint32
|
||||
}
|
||||
|
||||
func (plan *encodePlanStringToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||
s, ok := value.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected %v to be a string to attempt conversion to binary", value)
|
||||
}
|
||||
|
||||
var scannedValue any
|
||||
err = plan.m.Scan(plan.oid, TextFormatCode, []byte(s), &scannedValue)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tried to scan %v to convert to binary but failed: %v", value, err)
|
||||
}
|
||||
|
||||
newBuf, err = plan.m.Encode(plan.oid, BinaryFormatCode, scannedValue, buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tried to encode %v to binary via scanning but failed: %v", value, err)
|
||||
}
|
||||
|
||||
return newBuf, nil
|
||||
}
|
||||
|
||||
// TryWrapEncodePlanFunc is a function that tries to create a wrapper plan for value. If successful it returns a plan
|
||||
// that will convert the value passed to Encode and then call the next plan. nextValue is value as it will be converted
|
||||
// by plan. It must be used to find another suitable EncodePlan. When it is found SetNext must be called on plan for it
|
||||
@@ -1873,17 +1946,6 @@ func (m *Map) Encode(oid uint32, formatCode int16, value any, buf []byte) (newBu
|
||||
|
||||
plan := m.PlanEncode(oid, formatCode, value)
|
||||
if plan == nil {
|
||||
if dv, ok := value.(driver.Valuer); ok {
|
||||
if dv == nil {
|
||||
return nil, nil
|
||||
}
|
||||
v, err := dv.Value()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.Encode(oid, formatCode, v, buf)
|
||||
}
|
||||
|
||||
return nil, newEncodeError(value, m, oid, formatCode, errors.New("cannot find encode plan"))
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,9 @@ package pgtype_test
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"regexp"
|
||||
@@ -265,6 +267,56 @@ func TestMapScanPtrToPtrToSlice(t *testing.T) {
|
||||
require.Equal(t, []string{"foo", "bar"}, *v)
|
||||
}
|
||||
|
||||
type databaseValuerString string
|
||||
|
||||
func (s databaseValuerString) Value() (driver.Value, error) {
|
||||
return fmt.Sprintf("%d", len(s)), nil
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/1319
|
||||
func TestMapEncodeTextFormatDatabaseValuerThatIsRenamedSimpleType(t *testing.T) {
|
||||
m := pgtype.NewMap()
|
||||
src := databaseValuerString("foo")
|
||||
buf, err := m.Encode(pgtype.TextOID, pgtype.TextFormatCode, src, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "3", string(buf))
|
||||
}
|
||||
|
||||
type databaseValuerFmtStringer string
|
||||
|
||||
func (s databaseValuerFmtStringer) Value() (driver.Value, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s databaseValuerFmtStringer) String() string {
|
||||
return "foobar"
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/1311
|
||||
func TestMapEncodeTextFormatDatabaseValuerThatIsFmtStringer(t *testing.T) {
|
||||
m := pgtype.NewMap()
|
||||
src := databaseValuerFmtStringer("")
|
||||
buf, err := m.Encode(pgtype.TextOID, pgtype.TextFormatCode, src, nil)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, buf)
|
||||
}
|
||||
|
||||
type databaseValuerStringFormat struct {
|
||||
n int32
|
||||
}
|
||||
|
||||
func (v databaseValuerStringFormat) Value() (driver.Value, error) {
|
||||
return fmt.Sprint(v.n), nil
|
||||
}
|
||||
|
||||
func TestMapEncodeBinaryFormatDatabaseValuerThatReturnsString(t *testing.T) {
|
||||
m := pgtype.NewMap()
|
||||
src := databaseValuerStringFormat{n: 42}
|
||||
buf, err := m.Encode(pgtype.Int4OID, pgtype.BinaryFormatCode, src, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte{0, 0, 0, 42}, buf)
|
||||
}
|
||||
|
||||
func BenchmarkMapScanInt4IntoBinaryDecoder(b *testing.B) {
|
||||
m := pgtype.NewMap()
|
||||
src := []byte{0, 0, 0, 42}
|
||||
|
||||
Reference in New Issue
Block a user