diff --git a/pgtype/enum_codec.go b/pgtype/enum_codec.go index 3bf29f4a..970895d8 100644 --- a/pgtype/enum_codec.go +++ b/pgtype/enum_codec.go @@ -76,10 +76,11 @@ func (c *EnumCodec) lookupAndCacheString(src []byte) string { if s, found := c.membersMap[string(src)]; found { return s - } else { - c.membersMap[s] = s - return s } + + s := string(src) + c.membersMap[s] = s + return s } type scanPlanTextAnyToEnumString struct { diff --git a/pgtype/enum_codec_test.go b/pgtype/enum_codec_test.go new file mode 100644 index 00000000..139bfc34 --- /dev/null +++ b/pgtype/enum_codec_test.go @@ -0,0 +1,69 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/stretchr/testify/require" +) + +func TestEnumCodec(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Exec(context.Background(), `drop type if exists enum_test; + +create type enum_test as enum ('foo', 'bar', 'baz');`) + require.NoError(t, err) + defer conn.Exec(context.Background(), "drop type enum_test") + + dt, err := conn.LoadDataType(context.Background(), "enum_test") + require.NoError(t, err) + + conn.ConnInfo().RegisterDataType(*dt) + + var s string + err = conn.QueryRow(context.Background(), `select 'foo'::enum_test`).Scan(&s) + require.NoError(t, err) + require.Equal(t, "foo", s) + + err = conn.QueryRow(context.Background(), `select $1::enum_test`, "bar").Scan(&s) + require.NoError(t, err) + require.Equal(t, "bar", s) + + err = conn.QueryRow(context.Background(), `select 'foo'::enum_test`).Scan(&s) + require.NoError(t, err) + require.Equal(t, "foo", s) + + err = conn.QueryRow(context.Background(), `select $1::enum_test`, "bar").Scan(&s) + require.NoError(t, err) + require.Equal(t, "bar", s) + + err = conn.QueryRow(context.Background(), `select 'baz'::enum_test`).Scan(&s) + require.NoError(t, err) + require.Equal(t, "baz", s) +} + +func TestEnumCodecValues(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Exec(context.Background(), `drop type if exists enum_test; + +create type enum_test as enum ('foo', 'bar', 'baz');`) + require.NoError(t, err) + defer conn.Exec(context.Background(), "drop type enum_test") + + dt, err := conn.LoadDataType(context.Background(), "enum_test") + require.NoError(t, err) + + conn.ConnInfo().RegisterDataType(*dt) + + rows, err := conn.Query(context.Background(), `select 'foo'::enum_test`) + require.NoError(t, err) + require.True(t, rows.Next()) + values, err := rows.Values() + require.NoError(t, err) + require.Equal(t, values, []interface{}{"foo"}) +}