2
0

Replace EnumType with EnumCodec

This commit is contained in:
Jack Christensen
2022-01-11 20:46:10 -06:00
parent b57e0c419b
commit ae9be0b99e
5 changed files with 116 additions and 338 deletions
+1 -2
View File
@@ -918,8 +918,7 @@ func BenchmarkSelectManyRegisteredEnum(b *testing.B) {
err = conn.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "color").Scan(&oid)
require.NoError(b, err)
et := pgtype.NewEnumType("color", []string{"blue", "green", "orange"})
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: et, Name: "color", OID: oid})
conn.ConnInfo().RegisterDataType(pgtype.DataType{Name: "color", OID: oid, Codec: &pgtype.EnumCodec{}})
b.ResetTimer()
var x, y, z string
+114
View File
@@ -0,0 +1,114 @@
package pgtype
import (
"database/sql/driver"
"fmt"
)
// EnumCodec is a codec that caches the strings it decodes. If the same string is read multiple times only one copy is
// allocated. These strings are only garbage collected when the EnumCodec is garbage collected. EnumCodec can be used
// for any text type not only enums, but it should only be used when there are a small number of possible values.
type EnumCodec struct {
membersMap map[string]string // map to quickly lookup member and reuse string instead of allocating
}
func (EnumCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode
}
func (EnumCodec) PreferredFormat() int16 {
return TextFormatCode
}
func (EnumCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan {
switch format {
case TextFormatCode, BinaryFormatCode:
switch value.(type) {
case string:
return encodePlanTextCodecString{}
case []byte:
return encodePlanTextCodecByteSlice{}
case rune:
return encodePlanTextCodecRune{}
case fmt.Stringer:
return encodePlanTextCodecStringer{}
case TextValuer:
return encodePlanTextCodecTextValuer{}
}
}
return nil
}
func (c *EnumCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan {
switch format {
case TextFormatCode, BinaryFormatCode:
switch target.(type) {
case *string:
return &scanPlanTextAnyToEnumString{codec: c}
case *[]byte:
return scanPlanAnyToNewByteSlice{}
case TextScanner:
return &scanPlanTextAnyToEnumTextScanner{codec: c}
case *rune:
return scanPlanTextAnyToRune{}
}
}
return nil
}
func (c *EnumCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) {
return c.DecodeValue(ci, oid, format, src)
}
func (c *EnumCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) {
if src == nil {
return nil, nil
}
return c.lookupAndCacheString(src), nil
}
// lookupAndCacheString looks for src in the members map. If it is not found it is added to the map.
func (c *EnumCodec) lookupAndCacheString(src []byte) string {
if c.membersMap == nil {
c.membersMap = make(map[string]string)
}
if s, found := c.membersMap[string(src)]; found {
return s
} else {
c.membersMap[s] = s
return s
}
}
type scanPlanTextAnyToEnumString struct {
codec *EnumCodec
}
func (plan *scanPlanTextAnyToEnumString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
if src == nil {
return fmt.Errorf("cannot scan null into %T", dst)
}
p := (dst).(*string)
*p = plan.codec.lookupAndCacheString(src)
return nil
}
type scanPlanTextAnyToEnumTextScanner struct {
codec *EnumCodec
}
func (plan *scanPlanTextAnyToEnumTextScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
scanner := (dst).(TextScanner)
if src == nil {
return scanner.ScanText(Text{})
}
return scanner.ScanText(Text{String: plan.codec.lookupAndCacheString(src), Valid: true})
}
-158
View File
@@ -1,158 +0,0 @@
package pgtype
import "fmt"
// EnumType represents a enum type. While it implements Value, this is only in service of its type conversion duties
// when registered as a data type in a ConnType. It should not be used directly as a Value.
type EnumType struct {
value string
valid bool
typeName string // PostgreSQL type name
members []string // enum members
membersMap map[string]string // map to quickly lookup member and reuse string instead of allocating
}
// NewEnumType initializes a new EnumType. It retains a read-only reference to members. members must not be changed.
func NewEnumType(typeName string, members []string) *EnumType {
et := &EnumType{typeName: typeName, members: members}
et.membersMap = make(map[string]string, len(members))
for _, m := range members {
et.membersMap[m] = m
}
return et
}
func (et *EnumType) NewTypeValue() Value {
return &EnumType{
value: et.value,
valid: et.valid,
typeName: et.typeName,
members: et.members,
membersMap: et.membersMap,
}
}
func (et *EnumType) TypeName() string {
return et.typeName
}
func (et *EnumType) Members() []string {
return et.members
}
// Set assigns src to dst. Set purposely does not check that src is a member. This allows continued error free
// operation in the event the PostgreSQL enum type is modified during a connection.
func (dst *EnumType) Set(src interface{}) error {
if src == nil {
dst.valid = false
return nil
}
if value, ok := src.(interface{ Get() interface{} }); ok {
value2 := value.Get()
if value2 != value {
return dst.Set(value2)
}
}
switch value := src.(type) {
case string:
dst.value = value
dst.valid = true
case *string:
if value == nil {
dst.valid = false
} else {
dst.value = *value
dst.valid = true
}
case []byte:
if value == nil {
dst.valid = false
} else {
dst.value = string(value)
dst.valid = true
}
default:
if originalSrc, ok := underlyingStringType(src); ok {
return dst.Set(originalSrc)
}
return fmt.Errorf("cannot convert %v to enum %s", value, dst.typeName)
}
return nil
}
func (dst EnumType) Get() interface{} {
if !dst.valid {
return nil
}
return dst.value
}
func (src *EnumType) AssignTo(dst interface{}) error {
if !src.valid {
return NullAssignTo(dst)
}
switch v := dst.(type) {
case *string:
*v = src.value
return nil
case *[]byte:
*v = make([]byte, len(src.value))
copy(*v, src.value)
return nil
default:
if nextDst, retry := GetAssignToDstType(dst); retry {
return src.AssignTo(nextDst)
}
return fmt.Errorf("unable to assign to %T", dst)
}
}
func (EnumType) PreferredResultFormat() int16 {
return TextFormatCode
}
func (dst *EnumType) DecodeText(ci *ConnInfo, src []byte) error {
if src == nil {
dst.valid = false
return nil
}
// Lookup the string in membersMap to avoid an allocation.
if s, found := dst.membersMap[string(src)]; found {
dst.value = s
} else {
// If an enum type is modified after the initial connection it is possible to receive an unexpected value.
// Gracefully handle this situation. Purposely NOT modifying members and membersMap to allow for sharing members
// and membersMap between connections.
dst.value = string(src)
}
dst.valid = true
return nil
}
func (dst *EnumType) DecodeBinary(ci *ConnInfo, src []byte) error {
return dst.DecodeText(ci, src)
}
func (EnumType) PreferredParamFormat() int16 {
return TextFormatCode
}
func (src EnumType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
if !src.valid {
return nil, nil
}
return append(buf, src.value...), nil
}
func (src EnumType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
return src.EncodeText(ci, buf)
}
-148
View File
@@ -1,148 +0,0 @@
package pgtype_test
import (
"bytes"
"context"
"testing"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgtype/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setupEnum(t *testing.T, conn *pgx.Conn) *pgtype.EnumType {
_, err := conn.Exec(context.Background(), "drop type if exists pgtype_enum_color;")
require.NoError(t, err)
_, err = conn.Exec(context.Background(), "create type pgtype_enum_color as enum ('blue', 'green', 'purple');")
require.NoError(t, err)
var oid uint32
err = conn.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "pgtype_enum_color").Scan(&oid)
require.NoError(t, err)
et := pgtype.NewEnumType("pgtype_enum_color", []string{"blue", "green", "purple"})
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: et, Name: "pgtype_enum_color", OID: oid})
return et
}
func cleanupEnum(t *testing.T, conn *pgx.Conn) {
_, err := conn.Exec(context.Background(), "drop type if exists pgtype_enum_color;")
require.NoError(t, err)
}
func TestEnumTypeTranscode(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
setupEnum(t, conn)
defer cleanupEnum(t, conn)
var dst string
err := conn.QueryRow(context.Background(), "select $1::pgtype_enum_color", "blue").Scan(&dst)
require.NoError(t, err)
require.EqualValues(t, "blue", dst)
}
func TestEnumTypeSet(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
enumType := setupEnum(t, conn)
defer cleanupEnum(t, conn)
successfulTests := []struct {
source interface{}
result interface{}
}{
{source: "blue", result: "blue"},
{source: _string("green"), result: "green"},
{source: (*string)(nil), result: nil},
}
for i, tt := range successfulTests {
err := enumType.Set(tt.source)
assert.NoErrorf(t, err, "%d", i)
assert.Equalf(t, tt.result, enumType.Get(), "%d", i)
}
}
func TestEnumTypeAssignTo(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
enumType := setupEnum(t, conn)
defer cleanupEnum(t, conn)
{
var s string
err := enumType.Set("blue")
require.NoError(t, err)
err = enumType.AssignTo(&s)
require.NoError(t, err)
assert.EqualValues(t, "blue", s)
}
{
var ps *string
err := enumType.Set("blue")
require.NoError(t, err)
err = enumType.AssignTo(&ps)
require.NoError(t, err)
assert.EqualValues(t, "blue", *ps)
}
{
var ps *string
err := enumType.Set(nil)
require.NoError(t, err)
err = enumType.AssignTo(&ps)
require.NoError(t, err)
assert.EqualValues(t, (*string)(nil), ps)
}
var buf []byte
bytesTests := []struct {
src interface{}
dst *[]byte
expected []byte
}{
{src: "blue", dst: &buf, expected: []byte("blue")},
{src: nil, dst: &buf, expected: nil},
}
for i, tt := range bytesTests {
err := enumType.Set(tt.src)
require.NoError(t, err, "%d", i)
err = enumType.AssignTo(tt.dst)
require.NoError(t, err, "%d", i)
if bytes.Compare(*tt.dst, tt.expected) != 0 {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, tt.dst)
}
}
{
var s string
err := enumType.Set(nil)
require.NoError(t, err)
err = enumType.AssignTo(&s)
require.Error(t, err)
}
}
+1 -30
View File
@@ -65,11 +65,7 @@ func LoadDataType(ctx context.Context, conn Querier, ci *pgtype.ConnInfo, typeNa
// }
// return pgtype.DataType{Value: ct, Name: typeName, OID: oid}, nil
case "e": // enum
members, err := GetEnumMembers(ctx, conn, oid)
if err != nil {
return pgtype.DataType{}, err
}
return pgtype.DataType{Value: pgtype.NewEnumType(typeName, members), Name: typeName, OID: oid}, nil
return pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil
default:
return pgtype.DataType{}, errors.New("unknown typtype")
}
@@ -121,28 +117,3 @@ func GetArrayElementOID(ctx context.Context, conn Querier, oid uint32) (uint32,
// return fields, nil
// }
// GetEnumMembers gets the possible values of the enum by oid.
func GetEnumMembers(ctx context.Context, conn Querier, oid uint32) ([]string, error) {
members := []string{}
rows, err := conn.Query(ctx, "select enumlabel from pg_enum where enumtypid=$1 order by enumsortorder", oid)
if err != nil {
return nil, err
}
for rows.Next() {
var m string
err := rows.Scan(&m)
if err != nil {
return nil, err
}
members = append(members, m)
}
if rows.Err() != nil {
return nil, rows.Err()
}
return members, nil
}