diff --git a/conn.go b/conn.go index a65eb272..d564326c 100644 --- a/conn.go +++ b/conn.go @@ -763,7 +763,12 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} err = arg.EncodeBinary(wbuf) case TextEncoder: var s string - s, err = arg.EncodeText() + var status byte + s, status, err = arg.EncodeText() + if status == NullText { + wbuf.WriteInt32(-1) + continue + } wbuf.WriteInt32(int32(len(s))) wbuf.WriteBytes([]byte(s)) default: diff --git a/conn_test.go b/conn_test.go index 901fddfd..b0a5774d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -608,8 +608,8 @@ func TestQueryPreparedEncodeError(t *testing.T) { // works when the parameter type is a core type. type coreTextEncoder struct{} -func (n *coreTextEncoder) EncodeText() (string, error) { - return "42", nil +func (n *coreTextEncoder) EncodeText() (string, byte, error) { + return "42", pgx.SafeText, nil } func TestQueryPreparedEncodeCoreTextFormatError(t *testing.T) { diff --git a/values.go b/values.go index b48253ae..51f81b3e 100644 --- a/values.go +++ b/values.go @@ -30,6 +30,12 @@ const ( BinaryFormatCode = 1 ) +const ( + NullText = iota + SafeText = iota + UnsafeText = iota +) + type SerializationError string func (e SerializationError) Error() string { @@ -48,9 +54,10 @@ type Scanner interface { // queries and for prepared queries when the type does not implement // BinaryEncoder type TextEncoder interface { - // EncodeText MUST sanitize (and quote, if necessary) the returned string. - // It will be interpolated directly into the SQL string. - EncodeText() (string, error) + // EncodeText returns the value encoded into a string. status must be + // NullText if the value is NULL, UnsafeText if the value should be quoted + // and escaped, or SafeText if the value should not be quoted. + EncodeText() (val string, status byte, err error) } // BinaryEncoder is an interface used to encode values in binary format for @@ -81,11 +88,11 @@ func (n *NullFloat32) Scan(rows *Rows, fd *FieldDescription, size int32) error { return rows.Err() } -func (n NullFloat32) EncodeText() (string, error) { +func (n NullFloat32) EncodeText() (string, byte, error) { if n.Valid { - return strconv.FormatFloat(float64(n.Float32), 'f', -1, 32), nil + return strconv.FormatFloat(float64(n.Float32), 'f', -1, 32), SafeText, nil } else { - return "null", nil + return "", NullText, nil } } @@ -119,11 +126,11 @@ func (n *NullFloat64) Scan(rows *Rows, fd *FieldDescription, size int32) error { return rows.Err() } -func (n NullFloat64) EncodeText() (string, error) { +func (n NullFloat64) EncodeText() (string, byte, error) { if n.Valid { - return strconv.FormatFloat(n.Float64, 'f', -1, 64), nil + return strconv.FormatFloat(n.Float64, 'f', -1, 64), SafeText, nil } else { - return "null", nil + return "", NullText, nil } } @@ -136,6 +143,35 @@ func (n NullFloat64) EncodeBinary(w *WriteBuf) error { return encodeFloat8(w, n.Float64) } +// NullString represents an integer that may be null. NullString implements +// the Scanner and TextEncoder interfaces so it may be used both as an +// argument to Query[Row] and a destination for Scan for prepared and +// unprepared queries. +// +// If Valid is false then the value is NULL. +type NullString struct { + String string + Valid bool // Valid is true if Int64 is not NULL +} + +func (s *NullString) Scan(rows *Rows, fd *FieldDescription, size int32) error { + if size == -1 { + s.String, s.Valid = "", false + return nil + } + s.Valid = true + s.String = decodeText(rows, fd, size) + return rows.Err() +} + +func (s NullString) EncodeText() (string, byte, error) { + if s.Valid { + return s.String, UnsafeText, nil + } else { + return "", NullText, nil + } +} + // NullInt16 represents an smallint that may be null. // NullInt16 implements the Scanner, TextEncoder, and BinaryEncoder interfaces // so it may be used both as an argument to Query[Row] and a destination for @@ -157,11 +193,11 @@ func (n *NullInt16) Scan(rows *Rows, fd *FieldDescription, size int32) error { return rows.Err() } -func (n NullInt16) EncodeText() (string, error) { +func (n NullInt16) EncodeText() (string, byte, error) { if n.Valid { - return strconv.FormatInt(int64(n.Int16), 10), nil + return strconv.FormatInt(int64(n.Int16), 10), SafeText, nil } else { - return "null", nil + return "", NullText, nil } } @@ -174,7 +210,7 @@ func (n NullInt16) EncodeBinary(w *WriteBuf) error { return encodeInt2(w, n.Int16) } -// NullInt32 represents an integer that may be null. +// NullInt32 represents an smallint that may be null. // NullInt32 implements the Scanner, TextEncoder, and BinaryEncoder interfaces // so it may be used both as an argument to Query[Row] and a destination for // Scan for prepared and unprepared queries. @@ -195,11 +231,11 @@ func (n *NullInt32) Scan(rows *Rows, fd *FieldDescription, size int32) error { return rows.Err() } -func (n NullInt32) EncodeText() (string, error) { +func (n NullInt32) EncodeText() (string, byte, error) { if n.Valid { - return strconv.FormatInt(int64(n.Int32), 10), nil + return strconv.FormatInt(int64(n.Int32), 10), SafeText, nil } else { - return "null", nil + return "", NullText, nil } } @@ -233,11 +269,11 @@ func (n *NullInt64) Scan(rows *Rows, fd *FieldDescription, size int32) error { return rows.Err() } -func (n NullInt64) EncodeText() (string, error) { +func (n NullInt64) EncodeText() (string, byte, error) { if n.Valid { - return strconv.FormatInt(int64(n.Int64), 10), nil + return strconv.FormatInt(int64(n.Int64), 10), SafeText, nil } else { - return "null", nil + return "", NullText, nil } } @@ -327,7 +363,17 @@ func sanitizeArg(arg interface{}) (string, error) { case nil: return "null", nil case TextEncoder: - return arg.EncodeText() + s, status, err := arg.EncodeText() + switch status { + case NullText: + return "null", err + case UnsafeText: + return QuoteString(s), err + case SafeText: + return s, err + default: + return "", SerializationError("Received invalid status from EncodeText") + } default: return "", SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg)) } diff --git a/values_test.go b/values_test.go index 434c7dcf..ad0b2742 100644 --- a/values_test.go +++ b/values_test.go @@ -194,6 +194,7 @@ func TestNullX(t *testing.T) { defer closeConn(t, conn) type allTypes struct { + s pgx.NullString i16 pgx.NullInt16 i32 pgx.NullInt32 i64 pgx.NullInt64 @@ -209,6 +210,8 @@ func TestNullX(t *testing.T) { scanArgs []interface{} expected allTypes }{ + {"select $1::text", []interface{}{pgx.NullString{String: "foo", Valid: true}}, []interface{}{&actual.s}, allTypes{s: pgx.NullString{String: "foo", Valid: true}}}, + {"select $1::text", []interface{}{pgx.NullString{String: "foo", Valid: false}}, []interface{}{&actual.s}, allTypes{s: pgx.NullString{String: "", Valid: false}}}, {"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: true}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 1, Valid: true}}}, {"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: false}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 0, Valid: false}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 1, Valid: true}}},