From 4ff46becfcca993860f84362fdcc19aa2b077d07 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 9 Sep 2015 18:07:05 -0500 Subject: [PATCH] Generalize pointer to string uuid transcoding to any non-varchar/text type --- conn.go | 10 +++++++--- values_test.go | 15 ++++++++++++--- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/conn.go b/conn.go index b75c3ca9..8936e438 100644 --- a/conn.go +++ b/conn.go @@ -728,7 +728,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} switch arg := arguments[i].(type) { case Encoder: wbuf.WriteInt16(arg.FormatCode()) - case string: + case string, *string: wbuf.WriteInt16(TextFormatCode) default: switch oid { @@ -776,7 +776,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} err = encodeFloat4(wbuf, arguments[i]) case Float8Oid: err = encodeFloat8(wbuf, arguments[i]) - case TextOid, VarcharOid, UuidOid: + case TextOid, VarcharOid: err = encodeText(wbuf, arguments[i]) case DateOid: err = encodeDate(wbuf, arguments[i]) @@ -811,7 +811,11 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} case JsonOid, JsonbOid: err = encodeJson(wbuf, arguments[i]) default: - return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) + if s, ok := arguments[i].(string); ok { + err = encodeText(wbuf, s) + } else { + return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) + } } } if err != nil { diff --git a/values_test.go b/values_test.go index 00a03ab8..b93542d5 100644 --- a/values_test.go +++ b/values_test.go @@ -199,20 +199,29 @@ func mustParseCIDR(t *testing.T, s string) net.IPNet { return *ipnet } -func TestUuidTranscode(t *testing.T) { +func TestStringToNotTextTypeTranscode(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) input := "01086ee0-4963-4e35-9116-30c173a8d0bd" + var output string - err := conn.QueryRow("select $1::uuid", &input).Scan(&output) + err := conn.QueryRow("select $1::uuid", input).Scan(&output) if err != nil { t.Fatal(err) } if input != output { - t.Errorf("uuid: Did not transcode successfully: %s is not %s", input, output) + t.Errorf("uuid: Did not transcode string successfully: %s is not %s", input, output) + } + + err = conn.QueryRow("select $1::uuid", &input).Scan(&output) + if err != nil { + t.Fatal(err) + } + if input != output { + t.Errorf("uuid: Did not transcode pointer to string successfully: %s is not %s", input, output) } }