From 0bbaad1348a924d630c9ed4c68d6d999f94021da Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 24 Jan 2020 11:23:28 -0600 Subject: [PATCH] Add zeronull package for easier NULL <-> zero conversion --- testutil/testutil.go | 150 +++++++++++++++++++++++++++++++++++ zeronull/doc.go | 22 +++++ zeronull/int2.go | 90 +++++++++++++++++++++ zeronull/int2_test.go | 23 ++++++ zeronull/int4.go | 90 +++++++++++++++++++++ zeronull/int4_test.go | 23 ++++++ zeronull/int8.go | 90 +++++++++++++++++++++ zeronull/int8_test.go | 23 ++++++ zeronull/text.go | 90 +++++++++++++++++++++ zeronull/text_test.go | 23 ++++++ zeronull/timestamp.go | 91 +++++++++++++++++++++ zeronull/timestamp_test.go | 29 +++++++ zeronull/timestamptz.go | 91 +++++++++++++++++++++ zeronull/timestamptz_test.go | 29 +++++++ zeronull/uuid.go | 90 +++++++++++++++++++++ zeronull/uuid_test.go | 23 ++++++ 16 files changed, 977 insertions(+) create mode 100644 zeronull/doc.go create mode 100644 zeronull/int2.go create mode 100644 zeronull/int2_test.go create mode 100644 zeronull/int4.go create mode 100644 zeronull/int4_test.go create mode 100644 zeronull/int8.go create mode 100644 zeronull/int8_test.go create mode 100644 zeronull/text.go create mode 100644 zeronull/text_test.go create mode 100644 zeronull/timestamp.go create mode 100644 zeronull/timestamp_test.go create mode 100644 zeronull/timestamptz.go create mode 100644 zeronull/timestamptz_test.go create mode 100644 zeronull/uuid.go create mode 100644 zeronull/uuid_test.go diff --git a/testutil/testutil.go b/testutil/testutil.go index 068b7c59..e7b64b58 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -284,3 +284,153 @@ func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, t } } } + +func TestGoZeroToNullConversion(t testing.TB, pgTypeName string, zero interface{}) { + TestPgxGoZeroToNullConversion(t, pgTypeName, zero) + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + TestDatabaseSQLGoZeroToNullConversion(t, driverName, pgTypeName, zero) + } +} + +func TestNullToGoZeroConversion(t testing.TB, pgTypeName string, zero interface{}) { + TestPgxNullToGoZeroConversion(t, pgTypeName, zero) + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + TestDatabaseSQLNullToGoZeroConversion(t, driverName, pgTypeName, zero) + } +} + +func TestPgxGoZeroToNullConversion(t testing.TB, pgTypeName string, zero interface{}) { + conn := MustConnectPgx(t) + defer MustCloseContext(t, conn) + + _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s is null", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for _, paramFormat := range formats { + vEncoder := ForceEncoder(zero, paramFormat.formatCode) + if vEncoder == nil { + t.Logf("Skipping Param %s: %#v does not implement %v for encoding", paramFormat.name, zero, paramFormat.name) + continue + } + + var result bool + err := conn.QueryRow(context.Background(), "test", vEncoder).Scan(&result) + if err != nil { + t.Errorf("Param %s: %v", paramFormat.name, err) + } + + if !result { + t.Errorf("Param %s: did not convert zero to null", paramFormat.name) + } + } +} + +func TestPgxNullToGoZeroConversion(t testing.TB, pgTypeName string, zero interface{}) { + conn := MustConnectPgx(t) + defer MustCloseContext(t, conn) + + _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select null::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for _, resultFormat := range formats { + + switch resultFormat.formatCode { + case pgx.TextFormatCode: + if _, ok := zero.(pgtype.TextEncoder); !ok { + t.Logf("Skipping Result %s: %#v does not implement %v for decoding", resultFormat.name, zero, resultFormat.name) + continue + } + case pgx.BinaryFormatCode: + if _, ok := zero.(pgtype.BinaryEncoder); !ok { + t.Logf("Skipping Result %s: %#v does not implement %v for decoding", resultFormat.name, zero, resultFormat.name) + continue + } + } + + // Derefence value if it is a pointer + derefZero := zero + refVal := reflect.ValueOf(zero) + if refVal.Kind() == reflect.Ptr { + derefZero = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefZero)) + + err := conn.QueryRow(context.Background(), "test").Scan(result.Interface()) + if err != nil { + t.Errorf("Result %s: %v", resultFormat.name, err) + } + + if !reflect.DeepEqual(result.Elem().Interface(), derefZero) { + t.Errorf("Result %s: did not convert null to zero", resultFormat.name) + } + } +} + +func TestDatabaseSQLGoZeroToNullConversion(t testing.TB, driverName, pgTypeName string, zero interface{}) { + conn := MustConnectDatabaseSQL(t, driverName) + defer MustClose(t, conn) + + ps, err := conn.Prepare(fmt.Sprintf("select $1::%s is null", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + var result bool + err = ps.QueryRow(zero).Scan(&result) + if err != nil { + t.Errorf("%v %v", driverName, err) + } + + if !result { + t.Errorf("%v: did not convert zero to null", driverName) + } +} + +func TestDatabaseSQLNullToGoZeroConversion(t testing.TB, driverName, pgTypeName string, zero interface{}) { + conn := MustConnectDatabaseSQL(t, driverName) + defer MustClose(t, conn) + + ps, err := conn.Prepare(fmt.Sprintf("select null::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + // Derefence value if it is a pointer + derefZero := zero + refVal := reflect.ValueOf(zero) + if refVal.Kind() == reflect.Ptr { + derefZero = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefZero)) + + err = ps.QueryRow().Scan(result.Interface()) + if err != nil { + t.Errorf("%v %v", driverName, err) + } + + if !reflect.DeepEqual(result.Elem().Interface(), derefZero) { + t.Errorf("%s: did not convert null to zero", driverName) + } +} diff --git a/zeronull/doc.go b/zeronull/doc.go new file mode 100644 index 00000000..8db3507c --- /dev/null +++ b/zeronull/doc.go @@ -0,0 +1,22 @@ +// Package zeronull contains types that automatically convert between database NULLs and Go zero values. +/* +Sometimes the distinction between a zero value and a NULL value is not useful at the application level. For example, +in PostgreSQL an empty string may be stored as NULL. There is usually no application level distinction between an +empty string and a NULL string. Package zeronull implements types that seemlessly convert between PostgreSQL NULL and +the zero value. + +It is recommended to convert types at usage time rather than instantiate these types directly. In the example below, +middlename would be stored as a NULL. + + firstname := "John" + middlename := "" + lastname := "Smith" + _, err := conn.Exec( + ctx, + "insert into people(firstname, middlename, lastname) values($1, $2, $3)", + zeronull.Text(firstname), + zeronull.Text(middlename), + zeronull.Text(lastname), + ) +*/ +package zeronull diff --git a/zeronull/int2.go b/zeronull/int2.go new file mode 100644 index 00000000..a528642f --- /dev/null +++ b/zeronull/int2.go @@ -0,0 +1,90 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgtype" +) + +type Int2 int16 + +func (dst *Int2) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Int2 + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Int2(nullable.Int) + } else { + *dst = 0 + } + + return nil +} + +func (dst *Int2) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Int2 + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Int2(nullable.Int) + } else { + *dst = 0 + } + + return nil +} + +func (src Int2) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == 0 { + return nil, nil + } + + nullable := pgtype.Int2{ + Int: int16(src), + Status: pgtype.Present, + } + + return nullable.EncodeText(ci, buf) +} + +func (src Int2) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == 0 { + return nil, nil + } + + nullable := pgtype.Int2{ + Int: int16(src), + Status: pgtype.Present, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int2) Scan(src interface{}) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Int2 + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Int2(nullable.Int) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int2) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/zeronull/int2_test.go b/zeronull/int2_test.go new file mode 100644 index 00000000..2dcb4e79 --- /dev/null +++ b/zeronull/int2_test.go @@ -0,0 +1,23 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestInt2Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int2", []interface{}{ + (zeronull.Int2)(1), + (zeronull.Int2)(0), + }) +} + +func TestInt2ConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "int2", (zeronull.Int2)(0)) +} + +func TestInt2ConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "int2", (zeronull.Int2)(0)) +} diff --git a/zeronull/int4.go b/zeronull/int4.go new file mode 100644 index 00000000..c539e43a --- /dev/null +++ b/zeronull/int4.go @@ -0,0 +1,90 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgtype" +) + +type Int4 int32 + +func (dst *Int4) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Int4 + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Int4(nullable.Int) + } else { + *dst = 0 + } + + return nil +} + +func (dst *Int4) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Int4 + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Int4(nullable.Int) + } else { + *dst = 0 + } + + return nil +} + +func (src Int4) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == 0 { + return nil, nil + } + + nullable := pgtype.Int4{ + Int: int32(src), + Status: pgtype.Present, + } + + return nullable.EncodeText(ci, buf) +} + +func (src Int4) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == 0 { + return nil, nil + } + + nullable := pgtype.Int4{ + Int: int32(src), + Status: pgtype.Present, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int4) Scan(src interface{}) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Int4 + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Int4(nullable.Int) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int4) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/zeronull/int4_test.go b/zeronull/int4_test.go new file mode 100644 index 00000000..309e4125 --- /dev/null +++ b/zeronull/int4_test.go @@ -0,0 +1,23 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestInt4Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int4", []interface{}{ + (zeronull.Int4)(1), + (zeronull.Int4)(0), + }) +} + +func TestInt4ConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "int4", (zeronull.Int4)(0)) +} + +func TestInt4ConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "int4", (zeronull.Int4)(0)) +} diff --git a/zeronull/int8.go b/zeronull/int8.go new file mode 100644 index 00000000..19774645 --- /dev/null +++ b/zeronull/int8.go @@ -0,0 +1,90 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgtype" +) + +type Int8 int64 + +func (dst *Int8) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Int8 + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Int8(nullable.Int) + } else { + *dst = 0 + } + + return nil +} + +func (dst *Int8) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Int8 + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Int8(nullable.Int) + } else { + *dst = 0 + } + + return nil +} + +func (src Int8) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == 0 { + return nil, nil + } + + nullable := pgtype.Int8{ + Int: int64(src), + Status: pgtype.Present, + } + + return nullable.EncodeText(ci, buf) +} + +func (src Int8) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == 0 { + return nil, nil + } + + nullable := pgtype.Int8{ + Int: int64(src), + Status: pgtype.Present, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int8) Scan(src interface{}) error { + if src == nil { + *dst = 0 + return nil + } + + var nullable pgtype.Int8 + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Int8(nullable.Int) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int8) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/zeronull/int8_test.go b/zeronull/int8_test.go new file mode 100644 index 00000000..ae80bc0a --- /dev/null +++ b/zeronull/int8_test.go @@ -0,0 +1,23 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestInt8Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int8", []interface{}{ + (zeronull.Int8)(1), + (zeronull.Int8)(0), + }) +} + +func TestInt8ConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "int8", (zeronull.Int8)(0)) +} + +func TestInt8ConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "int8", (zeronull.Int8)(0)) +} diff --git a/zeronull/text.go b/zeronull/text.go new file mode 100644 index 00000000..8e79fc6a --- /dev/null +++ b/zeronull/text.go @@ -0,0 +1,90 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgtype" +) + +type Text string + +func (dst *Text) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Text + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Text(nullable.String) + } else { + *dst = Text("") + } + + return nil +} + +func (dst *Text) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Text + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Text(nullable.String) + } else { + *dst = Text("") + } + + return nil +} + +func (src Text) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == Text("") { + return nil, nil + } + + nullable := pgtype.Text{ + String: string(src), + Status: pgtype.Present, + } + + return nullable.EncodeText(ci, buf) +} + +func (src Text) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if src == Text("") { + return nil, nil + } + + nullable := pgtype.Text{ + String: string(src), + Status: pgtype.Present, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Text) Scan(src interface{}) error { + if src == nil { + *dst = Text("") + return nil + } + + var nullable pgtype.Text + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Text(nullable.String) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Text) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/zeronull/text_test.go b/zeronull/text_test.go new file mode 100644 index 00000000..f08a0d2a --- /dev/null +++ b/zeronull/text_test.go @@ -0,0 +1,23 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestTextTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "text", []interface{}{ + (zeronull.Text)("foo"), + (zeronull.Text)(""), + }) +} + +func TestTextConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "text", (zeronull.Text)("")) +} + +func TestTextConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "text", (zeronull.Text)("")) +} diff --git a/zeronull/timestamp.go b/zeronull/timestamp.go new file mode 100644 index 00000000..a94c67cc --- /dev/null +++ b/zeronull/timestamp.go @@ -0,0 +1,91 @@ +package zeronull + +import ( + "database/sql/driver" + "time" + + "github.com/jackc/pgtype" +) + +type Timestamp time.Time + +func (dst *Timestamp) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Timestamp + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Timestamp(nullable.Time) + } else { + *dst = Timestamp{} + } + + return nil +} + +func (dst *Timestamp) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Timestamp + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Timestamp(nullable.Time) + } else { + *dst = Timestamp{} + } + + return nil +} + +func (src Timestamp) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if (src == Timestamp{}) { + return nil, nil + } + + nullable := pgtype.Timestamp{ + Time: time.Time(src), + Status: pgtype.Present, + } + + return nullable.EncodeText(ci, buf) +} + +func (src Timestamp) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if (src == Timestamp{}) { + return nil, nil + } + + nullable := pgtype.Timestamp{ + Time: time.Time(src), + Status: pgtype.Present, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Timestamp) Scan(src interface{}) error { + if src == nil { + *dst = Timestamp{} + return nil + } + + var nullable pgtype.Timestamp + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Timestamp(nullable.Time) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Timestamp) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/zeronull/timestamp_test.go b/zeronull/timestamp_test.go new file mode 100644 index 00000000..ec96ff07 --- /dev/null +++ b/zeronull/timestamp_test.go @@ -0,0 +1,29 @@ +package zeronull_test + +import ( + "testing" + "time" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestTimestampTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp", []interface{}{ + (zeronull.Timestamp)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), + (zeronull.Timestamp)(time.Time{}), + }, func(a, b interface{}) bool { + at := a.(zeronull.Timestamp) + bt := b.(zeronull.Timestamp) + + return time.Time(at).Equal(time.Time(bt)) + }) +} + +func TestTimestampConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "timestamp", (zeronull.Timestamp)(time.Time{})) +} + +func TestTimestampConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "timestamp", (zeronull.Timestamp)(time.Time{})) +} diff --git a/zeronull/timestamptz.go b/zeronull/timestamptz.go new file mode 100644 index 00000000..c641ca10 --- /dev/null +++ b/zeronull/timestamptz.go @@ -0,0 +1,91 @@ +package zeronull + +import ( + "database/sql/driver" + "time" + + "github.com/jackc/pgtype" +) + +type Timestamptz time.Time + +func (dst *Timestamptz) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Timestamptz + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Timestamptz(nullable.Time) + } else { + *dst = Timestamptz{} + } + + return nil +} + +func (dst *Timestamptz) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.Timestamptz + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = Timestamptz(nullable.Time) + } else { + *dst = Timestamptz{} + } + + return nil +} + +func (src Timestamptz) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if (src == Timestamptz{}) { + return nil, nil + } + + nullable := pgtype.Timestamptz{ + Time: time.Time(src), + Status: pgtype.Present, + } + + return nullable.EncodeText(ci, buf) +} + +func (src Timestamptz) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if (src == Timestamptz{}) { + return nil, nil + } + + nullable := pgtype.Timestamptz{ + Time: time.Time(src), + Status: pgtype.Present, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Timestamptz) Scan(src interface{}) error { + if src == nil { + *dst = Timestamptz{} + return nil + } + + var nullable pgtype.Timestamptz + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = Timestamptz(nullable.Time) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src Timestamptz) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/zeronull/timestamptz_test.go b/zeronull/timestamptz_test.go new file mode 100644 index 00000000..3a401c49 --- /dev/null +++ b/zeronull/timestamptz_test.go @@ -0,0 +1,29 @@ +package zeronull_test + +import ( + "testing" + "time" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestTimestamptzTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ + (zeronull.Timestamptz)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), + (zeronull.Timestamptz)(time.Time{}), + }, func(a, b interface{}) bool { + at := a.(zeronull.Timestamptz) + bt := b.(zeronull.Timestamptz) + + return time.Time(at).Equal(time.Time(bt)) + }) +} + +func TestTimestamptzConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "timestamptz", (zeronull.Timestamptz)(time.Time{})) +} + +func TestTimestamptzConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "timestamptz", (zeronull.Timestamptz)(time.Time{})) +} diff --git a/zeronull/uuid.go b/zeronull/uuid.go new file mode 100644 index 00000000..18fc667e --- /dev/null +++ b/zeronull/uuid.go @@ -0,0 +1,90 @@ +package zeronull + +import ( + "database/sql/driver" + + "github.com/jackc/pgtype" +) + +type UUID [16]byte + +func (dst *UUID) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.UUID + err := nullable.DecodeText(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = UUID(nullable.Bytes) + } else { + *dst = UUID{} + } + + return nil +} + +func (dst *UUID) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + var nullable pgtype.UUID + err := nullable.DecodeBinary(ci, src) + if err != nil { + return err + } + + if nullable.Status == pgtype.Present { + *dst = UUID(nullable.Bytes) + } else { + *dst = UUID{} + } + + return nil +} + +func (src UUID) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if (src == UUID{}) { + return nil, nil + } + + nullable := pgtype.UUID{ + Bytes: [16]byte(src), + Status: pgtype.Present, + } + + return nullable.EncodeText(ci, buf) +} + +func (src UUID) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + if (src == UUID{}) { + return nil, nil + } + + nullable := pgtype.UUID{ + Bytes: [16]byte(src), + Status: pgtype.Present, + } + + return nullable.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *UUID) Scan(src interface{}) error { + if src == nil { + *dst = UUID{} + return nil + } + + var nullable pgtype.UUID + err := nullable.Scan(src) + if err != nil { + return err + } + + *dst = UUID(nullable.Bytes) + + return nil +} + +// Value implements the database/sql/driver Valuer interface. +func (src UUID) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/zeronull/uuid_test.go b/zeronull/uuid_test.go new file mode 100644 index 00000000..162bdf1f --- /dev/null +++ b/zeronull/uuid_test.go @@ -0,0 +1,23 @@ +package zeronull_test + +import ( + "testing" + + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgtype/zeronull" +) + +func TestUUIDTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ + (*zeronull.UUID)(&[16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}), + (*zeronull.UUID)(&[16]byte{}), + }) +} + +func TestUUIDConvertsGoZeroToNull(t *testing.T) { + testutil.TestGoZeroToNullConversion(t, "uuid", (*zeronull.UUID)(&[16]byte{})) +} + +func TestUUIDConvertsNullToGoZero(t *testing.T) { + testutil.TestNullToGoZeroConversion(t, "uuid", (*zeronull.UUID)(&[16]byte{})) +}