2
0

Merge remote-tracking branch 'upstream/master'

This commit is contained in:
2024-03-28 17:29:07 +03:00
101 changed files with 1509 additions and 591 deletions
+3 -1
View File
@@ -176,8 +176,10 @@ func (scanPlanBinaryBitsToBitsScanner) Scan(src []byte, dst any) error {
bitLen := int32(binary.BigEndian.Uint32(src))
rp := 4
buf := make([]byte, len(src[rp:]))
copy(buf, src[rp:])
return scanner.ScanBits(Bits{Bytes: src[rp:], Len: bitLen, Valid: true})
return scanner.ScanBits(Bits{Bytes: buf, Len: bitLen, Valid: true})
}
type scanPlanTextAnyToBitsScanner struct{}
+2 -2
View File
@@ -297,12 +297,12 @@ func (c Float4Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, sr
return nil, nil
}
var n float64
var n float32
err := codecScan(c, m, oid, format, src, &n)
if err != nil {
return nil, err
}
return n, nil
return float64(n), nil
}
func (c Float4Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
+17
View File
@@ -25,6 +25,11 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
case []byte:
return encodePlanJSONCodecEitherFormatByteSlice{}
// Handle json.RawMessage specifically because if it is run through json.Marshal it may be mutated.
// e.g. `{"foo": "bar"}` -> `{"foo":"bar"}`.
case json.RawMessage:
return encodePlanJSONCodecEitherFormatJSONRawMessage{}
// Cannot rely on driver.Valuer being handled later because anything can be marshalled.
//
// https://github.com/jackc/pgx/issues/1430
@@ -79,6 +84,18 @@ func (encodePlanJSONCodecEitherFormatByteSlice) Encode(value any, buf []byte) (n
return buf, nil
}
type encodePlanJSONCodecEitherFormatJSONRawMessage struct{}
func (encodePlanJSONCodecEitherFormatJSONRawMessage) Encode(value any, buf []byte) (newBuf []byte, err error) {
jsonBytes := value.(json.RawMessage)
if jsonBytes == nil {
return nil, nil
}
buf = append(buf, jsonBytes...)
return buf, nil
}
type encodePlanJSONCodecEitherFormatMarshal struct{}
func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) {
+122
View File
@@ -0,0 +1,122 @@
package pgtype
import (
"database/sql/driver"
"fmt"
)
type LtreeCodec struct{}
func (l LtreeCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode
}
// PreferredFormat returns the preferred format.
func (l LtreeCodec) PreferredFormat() int16 {
return TextFormatCode
}
// PlanEncode returns an EncodePlan for encoding value into PostgreSQL format for oid and format. If no plan can be
// found then nil is returned.
func (l LtreeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
switch format {
case TextFormatCode:
return (TextCodec)(l).PlanEncode(m, oid, format, value)
case BinaryFormatCode:
switch value.(type) {
case string:
return encodeLtreeCodecBinaryString{}
case []byte:
return encodeLtreeCodecBinaryByteSlice{}
case TextValuer:
return encodeLtreeCodecBinaryTextValuer{}
}
}
return nil
}
type encodeLtreeCodecBinaryString struct{}
func (encodeLtreeCodecBinaryString) Encode(value any, buf []byte) (newBuf []byte, err error) {
ltree := value.(string)
buf = append(buf, 1)
return append(buf, ltree...), nil
}
type encodeLtreeCodecBinaryByteSlice struct{}
func (encodeLtreeCodecBinaryByteSlice) Encode(value any, buf []byte) (newBuf []byte, err error) {
ltree := value.([]byte)
buf = append(buf, 1)
return append(buf, ltree...), nil
}
type encodeLtreeCodecBinaryTextValuer struct{}
func (encodeLtreeCodecBinaryTextValuer) Encode(value any, buf []byte) (newBuf []byte, err error) {
t, err := value.(TextValuer).TextValue()
if err != nil {
return nil, err
}
if !t.Valid {
return nil, nil
}
buf = append(buf, 1)
return append(buf, t.String...), nil
}
// PlanScan returns a ScanPlan for scanning a PostgreSQL value into a destination with the same type as target. If
// no plan can be found then nil is returned.
func (l LtreeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format {
case TextFormatCode:
return (TextCodec)(l).PlanScan(m, oid, format, target)
case BinaryFormatCode:
switch target.(type) {
case *string:
return scanPlanBinaryLtreeToString{}
case TextScanner:
return scanPlanBinaryLtreeToTextScanner{}
}
}
return nil
}
type scanPlanBinaryLtreeToString struct{}
func (scanPlanBinaryLtreeToString) Scan(src []byte, target any) error {
version := src[0]
if version != 1 {
return fmt.Errorf("unsupported ltree version %d", version)
}
p := (target).(*string)
*p = string(src[1:])
return nil
}
type scanPlanBinaryLtreeToTextScanner struct{}
func (scanPlanBinaryLtreeToTextScanner) Scan(src []byte, target any) error {
version := src[0]
if version != 1 {
return fmt.Errorf("unsupported ltree version %d", version)
}
scanner := (target).(TextScanner)
return scanner.ScanText(Text{String: string(src[1:]), Valid: true})
}
// DecodeDatabaseSQLValue returns src decoded into a value compatible with the sql.Scanner interface.
func (l LtreeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
return (TextCodec)(l).DecodeDatabaseSQLValue(m, oid, format, src)
}
// DecodeValue returns src decoded into its default format.
func (l LtreeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
return (TextCodec)(l).DecodeValue(m, oid, format, src)
}
+26
View File
@@ -0,0 +1,26 @@
package pgtype_test
import (
"context"
"testing"
"github.com/andoma-go/pgx/v5/pgtype"
"github.com/andoma-go/pgx/v5/pgxtest"
)
func TestLtreeCodec(t *testing.T) {
skipCockroachDB(t, "Server does not support type ltree")
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "ltree", []pgxtest.ValueRoundTripTest{
{
Param: "A.B.C",
Result: new(string),
Test: isExpectedEq("A.B.C"),
},
{
Param: pgtype.Text{String: "", Valid: true},
Result: new(pgtype.Text),
Test: isExpectedEq(pgtype.Text{String: "", Valid: true}),
},
})
}
+19
View File
@@ -48,4 +48,23 @@ func TestMacaddrCodec(t *testing.T) {
},
{nil, new(*net.HardwareAddr), isExpectedEq((*net.HardwareAddr)(nil))},
})
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "macaddr8", []pgxtest.ValueRoundTripTest{
{
mustParseMacaddr(t, "01:23:45:67:89:ab:01:08"),
new(net.HardwareAddr),
isExpectedEqHardwareAddr(mustParseMacaddr(t, "01:23:45:67:89:ab:01:08")),
},
{
"01:23:45:67:89:ab:01:08",
new(net.HardwareAddr),
isExpectedEqHardwareAddr(mustParseMacaddr(t, "01:23:45:67:89:ab:01:08")),
},
{
mustParseMacaddr(t, "01:23:45:67:89:ab:01:08"),
new(string),
isExpectedEq("01:23:45:67:89:ab:01:08"),
},
{nil, new(*net.HardwareAddr), isExpectedEq((*net.HardwareAddr)(nil))},
})
}
+4 -1
View File
@@ -41,6 +41,7 @@ const (
CircleOID = 718
CircleArrayOID = 719
UnknownOID = 705
Macaddr8OID = 774
MacaddrOID = 829
InetOID = 869
BoolArrayOID = 1000
@@ -81,6 +82,8 @@ const (
IntervalOID = 1186
IntervalArrayOID = 1187
NumericArrayOID = 1231
TimetzOID = 1266
TimetzArrayOID = 1270
BitOID = 1560
BitArrayOID = 1561
VarbitOID = 1562
@@ -559,7 +562,7 @@ func TryFindUnderlyingTypeScanPlan(dst any) (plan WrappedScanPlanNextSetter, nex
}
}
if nextDstType != nil && dstValue.Type() != nextDstType {
if nextDstType != nil && dstValue.Type() != nextDstType && dstValue.CanConvert(nextDstType) {
return &underlyingTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true
}
+3
View File
@@ -1,6 +1,7 @@
package pgtype
import (
"encoding/json"
"net"
"net/netip"
"reflect"
@@ -69,6 +70,7 @@ func initDefaultMap() {
defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}})
defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}})
defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}})
defaultMap.RegisterType(&Type{Name: "macaddr8", OID: Macaddr8OID, Codec: MacaddrCodec{}})
defaultMap.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}})
defaultMap.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}})
defaultMap.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}})
@@ -173,6 +175,7 @@ func initDefaultMap() {
registerDefaultPgTypeVariants[time.Time](defaultMap, "timestamptz")
registerDefaultPgTypeVariants[time.Duration](defaultMap, "interval")
registerDefaultPgTypeVariants[string](defaultMap, "text")
registerDefaultPgTypeVariants[json.RawMessage](defaultMap, "json")
registerDefaultPgTypeVariants[[]byte](defaultMap, "bytea")
registerDefaultPgTypeVariants[net.IP](defaultMap, "inet")
+17
View File
@@ -35,6 +35,7 @@ func init() {
// Test for renamed types
type _string string
type _bool bool
type _uint8 uint8
type _int8 int8
type _int16 int16
type _int16Slice []int16
@@ -453,6 +454,14 @@ func TestMapScanNullToWrongType(t *testing.T) {
assert.False(t, pn.Valid)
}
func TestScanToSliceOfRenamedUint8(t *testing.T) {
m := pgtype.NewMap()
var ruint8 []_uint8
err := m.Scan(pgtype.Int2ArrayOID, pgx.TextFormatCode, []byte("{2,4}"), &ruint8)
assert.NoError(t, err)
assert.Equal(t, []_uint8{2, 4}, ruint8)
}
func TestMapScanTextToBool(t *testing.T) {
tests := []struct {
name string
@@ -537,6 +546,14 @@ func TestMapEncodePlanCacheUUIDTypeConfusion(t *testing.T) {
require.Error(t, err)
}
// https://github.com/jackc/pgx/issues/1763
func TestMapEncodeRawJSONIntoUnknownOID(t *testing.T) {
m := pgtype.NewMap()
buf, err := m.Encode(0, pgtype.TextFormatCode, json.RawMessage(`{"foo": "bar"}`), nil)
require.NoError(t, err)
require.Equal(t, []byte(`{"foo": "bar"}`), buf)
}
func BenchmarkMapScanInt4IntoBinaryDecoder(b *testing.B) {
m := pgtype.NewMap()
src := []byte{0, 0, 0, 42}
+13 -1
View File
@@ -52,7 +52,19 @@ func parseUUID(src string) (dst [16]byte, err error) {
// encodeUUID converts a uuid byte array to UUID standard string form.
func encodeUUID(src [16]byte) string {
return fmt.Sprintf("%x-%x-%x-%x-%x", src[0:4], src[4:6], src[6:8], src[8:10], src[10:16])
var buf [36]byte
hex.Encode(buf[0:8], src[:4])
buf[8] = '-'
hex.Encode(buf[9:13], src[4:6])
buf[13] = '-'
hex.Encode(buf[14:18], src[6:8])
buf[18] = '-'
hex.Encode(buf[19:23], src[8:10])
buf[23] = '-'
hex.Encode(buf[24:], src[10:])
return string(buf[:])
}
// Scan implements the database/sql Scanner interface.