From fa36ad91967c7a90f61cfd8f14231d7b8cfe8785 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 7 Mar 2017 19:39:57 -0600 Subject: [PATCH] Move "char" to pgtype --- conn.go | 1 + pgtype/pgtype_test.go | 22 +++++-- pgtype/qchar.go | 144 ++++++++++++++++++++++++++++++++++++++++++ pgtype/qchar_test.go | 140 ++++++++++++++++++++++++++++++++++++++++ values.go | 80 ----------------------- values_test.go | 4 -- 6 files changed, 300 insertions(+), 91 deletions(-) create mode 100644 pgtype/qchar.go create mode 100644 pgtype/qchar_test.go diff --git a/conn.go b/conn.go index 023b9d97..f9f94c43 100644 --- a/conn.go +++ b/conn.go @@ -270,6 +270,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.oidPgtypeValues = map[OID]pgtype.Value{ BoolArrayOID: &pgtype.BoolArray{}, BoolOID: &pgtype.Bool{}, + CharOID: &pgtype.QChar{}, CIDOID: &pgtype.CID{}, CidrArrayOID: &pgtype.CidrArray{}, CidrOID: &pgtype.Inet{}, diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 304fd0ea..c1dba383 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -74,12 +74,15 @@ func (f forceBinaryEncoder) EncodeBinary(w io.Writer) error { func forceEncoder(e interface{}, formatCode int16) interface{} { switch formatCode { case pgx.TextFormatCode: - return forceTextEncoder{e: e.(pgtype.TextEncoder)} + if e, ok := e.(pgtype.TextEncoder); ok { + return forceTextEncoder{e: e} + } case pgx.BinaryFormatCode: - return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)} - default: - panic("bad encoder") + if e, ok := e.(pgtype.BinaryEncoder); ok { + return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)} + } } + return nil } func testSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) { @@ -105,9 +108,14 @@ func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, } - for _, fc := range formats { - ps.FieldDescriptions[0].FormatCode = fc.formatCode - for i, v := range values { + for i, v := range values { + for _, fc := range formats { + ps.FieldDescriptions[0].FormatCode = fc.formatCode + vEncoder := forceEncoder(v, fc.formatCode) + if vEncoder == nil { + t.Logf("%v does not implement %v", fc.name) + continue + } // Derefence value if it is a pointer derefV := v refVal := reflect.ValueOf(v) diff --git a/pgtype/qchar.go b/pgtype/qchar.go new file mode 100644 index 00000000..6dd14625 --- /dev/null +++ b/pgtype/qchar.go @@ -0,0 +1,144 @@ +package pgtype + +import ( + "fmt" + "io" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +// QChar is for PostgreSQL's special 8-bit-only "char" type more akin to the C +// language's char type, or Go's byte type. (Note that the name in PostgreSQL +// itself is "char", in double-quotes, and not char.) It gets used a lot in +// PostgreSQL's system tables to hold a single ASCII character value (eg +// pg_class.relkind). It is named Qchar for quoted char to disambiguate from SQL +// standard type char. +// +// Not all possible values of QChar are representable in the text format. +// Therefore, QChar does not implement TextEncoder and TextDecoder. +type QChar struct { + Int int8 + Status Status +} + +func (dst *QChar) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case QChar: + *dst = value + case int8: + *dst = QChar{Int: value, Status: Present} + case uint8: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case int16: + if value < math.MinInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case uint16: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case int32: + if value < math.MinInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case uint32: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case int64: + if value < math.MinInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case uint64: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case int: + if value < math.MinInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case uint: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 8) + if err != nil { + return err + } + *dst = QChar{Int: int8(num), Status: Present} + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to QChar", value) + } + + return nil +} + +func (src *QChar) AssignTo(dst interface{}) error { + return int64AssignTo(int64(src.Int), src.Status, dst) +} + +func (dst *QChar) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = QChar{Status: Null} + return nil + } + + if size != 1 { + return fmt.Errorf(`invalid length for "char": %v`, size) + } + + byt, err := pgio.ReadByte(r) + if err != nil { + return err + } + + *dst = QChar{Int: int8(byt), Status: Present} + return nil +} + +func (src QChar) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 1) + if err != nil { + return nil + } + + return pgio.WriteByte(w, byte(src.Int)) +} diff --git a/pgtype/qchar_test.go b/pgtype/qchar_test.go new file mode 100644 index 00000000..ea7b56a8 --- /dev/null +++ b/pgtype/qchar_test.go @@ -0,0 +1,140 @@ +package pgtype_test + +import ( + "math" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestQCharTranscode(t *testing.T) { + testSuccessfulTranscode(t, `"char"`, []interface{}{ + pgtype.QChar{Int: math.MinInt8, Status: pgtype.Present}, + pgtype.QChar{Int: -1, Status: pgtype.Present}, + pgtype.QChar{Int: 0, Status: pgtype.Present}, + pgtype.QChar{Int: 1, Status: pgtype.Present}, + pgtype.QChar{Int: math.MaxInt8, Status: pgtype.Present}, + pgtype.QChar{Int: 0, Status: pgtype.Null}, + }) +} + +func TestQCharConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.QChar + }{ + {source: int8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.QChar + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestQCharAssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + + simpleTests := []struct { + src pgtype.QChar + dst interface{} + expected interface{} + }{ + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.QChar + dst interface{} + expected interface{} + }{ + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.QChar + dst interface{} + }{ + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &i16}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/values.go b/values.go index 8e7ef4ac..c724aa39 100644 --- a/values.go +++ b/values.go @@ -371,52 +371,6 @@ func (n NullAclItem) Encode(w *WriteBuf, oid OID) error { return encodeString(w, oid, string(n.AclItem)) } -// The pgx.Char type is for PostgreSQL's special 8-bit-only -// "char" type more akin to the C language's char type, or Go's byte type. -// (Note that the name in PostgreSQL itself is "char", in double-quotes, -// and not char.) It gets used a lot in PostgreSQL's system tables to hold -// a single ASCII character value (eg pg_class.relkind). -type Char byte - -// NullChar represents a pgx.Char that may be null. NullChar implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan for prepared and unprepared queries. -// -// If Valid is false then the value is NULL. -type NullChar struct { - Char Char - Valid bool // Valid is true if Char is not NULL -} - -func (n *NullChar) Scan(vr *ValueReader) error { - if vr.Type().DataType != CharOID { - return SerializationError(fmt.Sprintf("NullChar.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Char, n.Valid = 0, false - return nil - } - n.Valid = true - n.Char = decodeChar(vr) - return vr.Err() -} - -func (n NullChar) FormatCode() int16 { return BinaryFormatCode } - -func (n NullChar) Encode(w *WriteBuf, oid OID) error { - if oid != CharOID { - return SerializationError(fmt.Sprintf("NullChar.Encode cannot encode into OID %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeChar(w, oid, n.Char) -} - // NullInt16 represents a smallint that may be null. NullInt16 implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan for prepared and unprepared queries. @@ -945,8 +899,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { } switch arg := arg.(type) { - case Char: - return encodeChar(wbuf, oid, arg) case AclItem: // The aclitem data type goes over the wire using the same format as string, // so just cast to string and use encodeString @@ -1018,8 +970,6 @@ func decodeByOID(vr *ValueReader) (interface{}, error) { // decoding to the built-in functionality. func Decode(vr *ValueReader, d interface{}) error { switch v := d.(type) { - case *Char: - *v = decodeChar(vr) case *AclItem: // aclitem goes over the wire just like text *v = AclItem(decodeText(vr)) @@ -1158,30 +1108,6 @@ func decodeInt8(vr *ValueReader) int64 { return n.Int } -func decodeChar(vr *ValueReader) Char { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into char")) - return Char(0) - } - - if vr.Type().DataType != CharOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into char", vr.Type().DataType))) - return Char(0) - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return Char(0) - } - - if vr.Len() != 1 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a char: %d", vr.Len()))) - return Char(0) - } - - return Char(vr.ReadByte()) -} - func decodeInt2(vr *ValueReader) int16 { if vr.Type().DataType != Int2OID { @@ -1216,12 +1142,6 @@ func decodeInt2(vr *ValueReader) int16 { return n.Int } -func encodeChar(w *WriteBuf, oid OID, value Char) error { - w.WriteInt32(1) - w.WriteByte(byte(value)) - return nil -} - func decodeInt4(vr *ValueReader) int32 { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into int32")) diff --git a/values_test.go b/values_test.go index 0e51effe..4c02ac0a 100644 --- a/values_test.go +++ b/values_test.go @@ -568,7 +568,6 @@ func TestNullX(t *testing.T) { s pgx.NullString i16 pgx.NullInt16 i32 pgx.NullInt32 - c pgx.NullChar a pgx.NullAclItem tid pgx.NullTid i64 pgx.NullInt64 @@ -592,9 +591,6 @@ func TestNullX(t *testing.T) { {"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: false}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 0, Valid: false}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 1, Valid: true}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: false}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 0, Valid: false}}}, - {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 1, Valid: true}}}, - {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: false}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 0, Valid: false}}}, - {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 255, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 255, Valid: true}}}, {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}}, {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: false}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "", Valid: false}}}, // A tricky (and valid) aclitem can still be used, especially with Go's useful backticks