From ae9be0b99ed53d3db94444a994b954fdf3f52d26 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 11 Jan 2022 20:46:10 -0600 Subject: [PATCH] Replace EnumType with EnumCodec --- bench_test.go | 3 +- pgtype/enum_codec.go | 114 +++++++++++++++++++++++++++ pgtype/enum_type.go | 158 -------------------------------------- pgtype/enum_type_test.go | 148 ----------------------------------- pgtype/pgxtype/pgxtype.go | 31 +------- 5 files changed, 116 insertions(+), 338 deletions(-) create mode 100644 pgtype/enum_codec.go delete mode 100644 pgtype/enum_type.go delete mode 100644 pgtype/enum_type_test.go diff --git a/bench_test.go b/bench_test.go index 9b14b7d3..c49c87f6 100644 --- a/bench_test.go +++ b/bench_test.go @@ -918,8 +918,7 @@ func BenchmarkSelectManyRegisteredEnum(b *testing.B) { err = conn.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "color").Scan(&oid) require.NoError(b, err) - et := pgtype.NewEnumType("color", []string{"blue", "green", "orange"}) - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: et, Name: "color", OID: oid}) + conn.ConnInfo().RegisterDataType(pgtype.DataType{Name: "color", OID: oid, Codec: &pgtype.EnumCodec{}}) b.ResetTimer() var x, y, z string diff --git a/pgtype/enum_codec.go b/pgtype/enum_codec.go new file mode 100644 index 00000000..9a37f1dd --- /dev/null +++ b/pgtype/enum_codec.go @@ -0,0 +1,114 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" +) + +// EnumCodec is a codec that caches the strings it decodes. If the same string is read multiple times only one copy is +// allocated. These strings are only garbage collected when the EnumCodec is garbage collected. EnumCodec can be used +// for any text type not only enums, but it should only be used when there are a small number of possible values. 
+type EnumCodec struct { + membersMap map[string]string // map to quickly lookup member and reuse string instead of allocating +} + +func (EnumCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (EnumCodec) PreferredFormat() int16 { + return TextFormatCode +} + +func (EnumCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case TextFormatCode, BinaryFormatCode: + switch value.(type) { + case string: + return encodePlanTextCodecString{} + case []byte: + return encodePlanTextCodecByteSlice{} + case rune: + return encodePlanTextCodecRune{} + case fmt.Stringer: + return encodePlanTextCodecStringer{} + case TextValuer: + return encodePlanTextCodecTextValuer{} + } + } + + return nil +} + +func (c *EnumCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + switch format { + case TextFormatCode, BinaryFormatCode: + switch target.(type) { + case *string: + return &scanPlanTextAnyToEnumString{codec: c} + case *[]byte: + return scanPlanAnyToNewByteSlice{} + case TextScanner: + return &scanPlanTextAnyToEnumTextScanner{codec: c} + case *rune: + return scanPlanTextAnyToRune{} + } + } + + return nil +} + +func (c *EnumCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return c.DecodeValue(ci, oid, format, src) +} + +func (c *EnumCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + return c.lookupAndCacheString(src), nil +} + +// lookupAndCacheString looks for src in the members map. If it is not found it is added to the map. 
+func (c *EnumCodec) lookupAndCacheString(src []byte) string { + if c.membersMap == nil { + c.membersMap = make(map[string]string) + } + // The lookup keyed by string(src) does not allocate; only a miss copies src into a new string, which is then cached. + s, found := c.membersMap[string(src)] + if !found { + s = string(src) + c.membersMap[s] = s + } + return s +} + +type scanPlanTextAnyToEnumString struct { + codec *EnumCodec +} + +func (plan *scanPlanTextAnyToEnumString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + p := (dst).(*string) + *p = plan.codec.lookupAndCacheString(src) + + return nil +} + +type scanPlanTextAnyToEnumTextScanner struct { + codec *EnumCodec +} + +func (plan *scanPlanTextAnyToEnumTextScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(TextScanner) + + if src == nil { + return scanner.ScanText(Text{}) + } + + return scanner.ScanText(Text{String: plan.codec.lookupAndCacheString(src), Valid: true}) +} diff --git a/pgtype/enum_type.go b/pgtype/enum_type.go deleted file mode 100644 index 73ee3823..00000000 --- a/pgtype/enum_type.go +++ /dev/null @@ -1,158 +0,0 @@ -package pgtype - -import "fmt" - -// EnumType represents a enum type. While it implements Value, this is only in service of its type conversion duties -// when registered as a data type in a ConnType. It should not be used directly as a Value. -type EnumType struct { - value string - valid bool - - typeName string // PostgreSQL type name - members []string // enum members - membersMap map[string]string // map to quickly lookup member and reuse string instead of allocating -} - -// NewEnumType initializes a new EnumType. It retains a read-only reference to members. members must not be changed. 
-func NewEnumType(typeName string, members []string) *EnumType { - et := &EnumType{typeName: typeName, members: members} - et.membersMap = make(map[string]string, len(members)) - for _, m := range members { - et.membersMap[m] = m - } - return et -} - -func (et *EnumType) NewTypeValue() Value { - return &EnumType{ - value: et.value, - valid: et.valid, - - typeName: et.typeName, - members: et.members, - membersMap: et.membersMap, - } -} - -func (et *EnumType) TypeName() string { - return et.typeName -} - -func (et *EnumType) Members() []string { - return et.members -} - -// Set assigns src to dst. Set purposely does not check that src is a member. This allows continued error free -// operation in the event the PostgreSQL enum type is modified during a connection. -func (dst *EnumType) Set(src interface{}) error { - if src == nil { - dst.valid = false - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case string: - dst.value = value - dst.valid = true - case *string: - if value == nil { - dst.valid = false - } else { - dst.value = *value - dst.valid = true - } - case []byte: - if value == nil { - dst.valid = false - } else { - dst.value = string(value) - dst.valid = true - } - default: - if originalSrc, ok := underlyingStringType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to enum %s", value, dst.typeName) - } - - return nil -} - -func (dst EnumType) Get() interface{} { - if !dst.valid { - return nil - } - return dst.value -} - -func (src *EnumType) AssignTo(dst interface{}) error { - if !src.valid { - return NullAssignTo(dst) - } - - switch v := dst.(type) { - case *string: - *v = src.value - return nil - case *[]byte: - *v = make([]byte, len(src.value)) - copy(*v, src.value) - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) 
- } - return fmt.Errorf("unable to assign to %T", dst) - } -} - -func (EnumType) PreferredResultFormat() int16 { - return TextFormatCode -} - -func (dst *EnumType) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - dst.valid = false - return nil - } - - // Lookup the string in membersMap to avoid an allocation. - if s, found := dst.membersMap[string(src)]; found { - dst.value = s - } else { - // If an enum type is modified after the initial connection it is possible to receive an unexpected value. - // Gracefully handle this situation. Purposely NOT modifying members and membersMap to allow for sharing members - // and membersMap between connections. - dst.value = string(src) - } - dst.valid = true - - return nil -} - -func (dst *EnumType) DecodeBinary(ci *ConnInfo, src []byte) error { - return dst.DecodeText(ci, src) -} - -func (EnumType) PreferredParamFormat() int16 { - return TextFormatCode -} - -func (src EnumType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.valid { - return nil, nil - } - - return append(buf, src.value...), nil -} - -func (src EnumType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return src.EncodeText(ci, buf) -} diff --git a/pgtype/enum_type_test.go b/pgtype/enum_type_test.go deleted file mode 100644 index 903b742f..00000000 --- a/pgtype/enum_type_test.go +++ /dev/null @@ -1,148 +0,0 @@ -package pgtype_test - -import ( - "bytes" - "context" - "testing" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func setupEnum(t *testing.T, conn *pgx.Conn) *pgtype.EnumType { - _, err := conn.Exec(context.Background(), "drop type if exists pgtype_enum_color;") - require.NoError(t, err) - - _, err = conn.Exec(context.Background(), "create type pgtype_enum_color as enum ('blue', 'green', 'purple');") - require.NoError(t, err) - - var oid uint32 - err = 
conn.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "pgtype_enum_color").Scan(&oid) - require.NoError(t, err) - - et := pgtype.NewEnumType("pgtype_enum_color", []string{"blue", "green", "purple"}) - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: et, Name: "pgtype_enum_color", OID: oid}) - - return et -} - -func cleanupEnum(t *testing.T, conn *pgx.Conn) { - _, err := conn.Exec(context.Background(), "drop type if exists pgtype_enum_color;") - require.NoError(t, err) -} - -func TestEnumTypeTranscode(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - setupEnum(t, conn) - defer cleanupEnum(t, conn) - - var dst string - err := conn.QueryRow(context.Background(), "select $1::pgtype_enum_color", "blue").Scan(&dst) - require.NoError(t, err) - require.EqualValues(t, "blue", dst) -} - -func TestEnumTypeSet(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - enumType := setupEnum(t, conn) - defer cleanupEnum(t, conn) - - successfulTests := []struct { - source interface{} - result interface{} - }{ - {source: "blue", result: "blue"}, - {source: _string("green"), result: "green"}, - {source: (*string)(nil), result: nil}, - } - - for i, tt := range successfulTests { - err := enumType.Set(tt.source) - assert.NoErrorf(t, err, "%d", i) - assert.Equalf(t, tt.result, enumType.Get(), "%d", i) - } -} - -func TestEnumTypeAssignTo(t *testing.T) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - enumType := setupEnum(t, conn) - defer cleanupEnum(t, conn) - - { - var s string - - err := enumType.Set("blue") - require.NoError(t, err) - - err = enumType.AssignTo(&s) - require.NoError(t, err) - - assert.EqualValues(t, "blue", s) - } - - { - var ps *string - - err := enumType.Set("blue") - require.NoError(t, err) - - err = enumType.AssignTo(&ps) - require.NoError(t, err) - - assert.EqualValues(t, "blue", *ps) - } - - { - var ps 
*string - - err := enumType.Set(nil) - require.NoError(t, err) - - err = enumType.AssignTo(&ps) - require.NoError(t, err) - - assert.EqualValues(t, (*string)(nil), ps) - } - - var buf []byte - bytesTests := []struct { - src interface{} - dst *[]byte - expected []byte - }{ - {src: "blue", dst: &buf, expected: []byte("blue")}, - {src: nil, dst: &buf, expected: nil}, - } - - for i, tt := range bytesTests { - err := enumType.Set(tt.src) - require.NoError(t, err, "%d", i) - - err = enumType.AssignTo(tt.dst) - require.NoError(t, err, "%d", i) - - if bytes.Compare(*tt.dst, tt.expected) != 0 { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, tt.dst) - } - } - - { - var s string - - err := enumType.Set(nil) - require.NoError(t, err) - - err = enumType.AssignTo(&s) - require.Error(t, err) - } - -} diff --git a/pgtype/pgxtype/pgxtype.go b/pgtype/pgxtype/pgxtype.go index 4f2c5796..6b5068e2 100644 --- a/pgtype/pgxtype/pgxtype.go +++ b/pgtype/pgxtype/pgxtype.go @@ -65,11 +65,7 @@ func LoadDataType(ctx context.Context, conn Querier, ci *pgtype.ConnInfo, typeNa // } // return pgtype.DataType{Value: ct, Name: typeName, OID: oid}, nil case "e": // enum - members, err := GetEnumMembers(ctx, conn, oid) - if err != nil { - return pgtype.DataType{}, err - } - return pgtype.DataType{Value: pgtype.NewEnumType(typeName, members), Name: typeName, OID: oid}, nil + return pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil default: return pgtype.DataType{}, errors.New("unknown typtype") } @@ -121,28 +117,3 @@ func GetArrayElementOID(ctx context.Context, conn Querier, oid uint32) (uint32, // return fields, nil // } - -// GetEnumMembers gets the possible values of the enum by oid. 
-func GetEnumMembers(ctx context.Context, conn Querier, oid uint32) ([]string, error) { - members := []string{} - - rows, err := conn.Query(ctx, "select enumlabel from pg_enum where enumtypid=$1 order by enumsortorder", oid) - if err != nil { - return nil, err - } - - for rows.Next() { - var m string - err := rows.Scan(&m) - if err != nil { - return nil, err - } - members = append(members, m) - } - - if rows.Err() != nil { - return nil, rows.Err() - } - - return members, nil -}