Always use bound parameters
PostgreSQL has two string syntaxes, one that allows backslash escapes and one that does not (SQL standard conforming strings). By default PostgreSQL uses standard conforming strings. QuoteString was only designed for use with standard conforming strings. If PostgreSQL was configured with certain combinations of the standard_conforming_strings and backslash_quote settings, QuoteString may not correctly sanitize strings. QuoteString was only used in unprepared queries, bound parameters are used for prepared queries. This commit alters pgx to use always use bound parameters. As a consequence of never doing string interpolation there is no need to have separate Text and Binary encoders. There is now only the Encoder interface. This change had a negative effect on the performance of simple unprepared queries, but prepared statements should already be used for performance. fixes #26 https://github.com/jackc/pgx/issues/26
This commit is contained in:
@@ -142,11 +142,9 @@ configure the TLS connection.
|
|||||||
pgx includes support for the common data types like integers, floats, strings,
|
pgx includes support for the common data types like integers, floats, strings,
|
||||||
dates, and times that have direct mappings between Go and SQL. Support can be
|
dates, and times that have direct mappings between Go and SQL. Support can be
|
||||||
added for additional types like point, hstore, numeric, etc. that do not have
|
added for additional types like point, hstore, numeric, etc. that do not have
|
||||||
direct mappings in Go by the types implementing Scanner, TextEncoder, and
|
direct mappings in Go by the types implementing Scanner and Encoder. See
|
||||||
optionally BinaryEncoder. To enable binary format for custom types, a prepared
|
example_custom_type_test.go for an example of a custom type for the PostgreSQL
|
||||||
statement must be used and the field description of the returned field must have
|
point type.
|
||||||
FormatCode set to BinaryFormatCode. See example_custom_type_test.go for an
|
|
||||||
example of a custom type for the PostgreSQL point type.
|
|
||||||
|
|
||||||
### Null Mapping
|
### Null Mapping
|
||||||
|
|
||||||
|
|||||||
@@ -315,8 +315,40 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
|
|||||||
// Deallocate released a prepared statement
|
// Deallocate released a prepared statement
|
||||||
func (c *Conn) Deallocate(name string) (err error) {
|
func (c *Conn) Deallocate(name string) (err error) {
|
||||||
delete(c.preparedStatements, name)
|
delete(c.preparedStatements, name)
|
||||||
_, err = c.Exec("deallocate " + QuoteIdentifier(name))
|
|
||||||
return
|
// close
|
||||||
|
wbuf := newWriteBuf(c.wbuf[0:0], 'C')
|
||||||
|
wbuf.WriteByte('S')
|
||||||
|
wbuf.WriteCString(name)
|
||||||
|
|
||||||
|
// flush
|
||||||
|
wbuf.startMsg('H')
|
||||||
|
wbuf.closeMsg()
|
||||||
|
wbuf.closeMsg()
|
||||||
|
|
||||||
|
_, err = c.conn.Write(wbuf.buf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
var t byte
|
||||||
|
var r *msgReader
|
||||||
|
t, r, err := c.rxMsg()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch t {
|
||||||
|
case closeComplete:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
err = c.processContextFreeMsg(t, r)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listen establishes a PostgreSQL listen/notify to channel
|
// Listen establishes a PostgreSQL listen/notify to channel
|
||||||
@@ -400,24 +432,27 @@ func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) sendSimpleQuery(sql string, arguments ...interface{}) (err error) {
|
func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
|
||||||
if len(arguments) > 0 {
|
if len(args) == 0 {
|
||||||
sql, err = SanitizeSql(sql, arguments...)
|
wbuf := newWriteBuf(c.wbuf[0:0], 'Q')
|
||||||
|
wbuf.WriteCString(sql)
|
||||||
|
wbuf.closeMsg()
|
||||||
|
|
||||||
|
_, err := c.conn.Write(wbuf.buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
c.die(err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
wbuf := newWriteBuf(c.wbuf[0:0], 'Q')
|
ps, err := c.Prepare("", sql)
|
||||||
wbuf.WriteCString(sql)
|
|
||||||
wbuf.closeMsg()
|
|
||||||
|
|
||||||
_, err = c.conn.Write(wbuf.buf)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.die(err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return c.sendPreparedQuery(ps, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}) (err error) {
|
func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}) (err error) {
|
||||||
@@ -433,10 +468,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||||||
wbuf.WriteInt16(int16(len(ps.ParameterOids)))
|
wbuf.WriteInt16(int16(len(ps.ParameterOids)))
|
||||||
for i, oid := range ps.ParameterOids {
|
for i, oid := range ps.ParameterOids {
|
||||||
switch arg := arguments[i].(type) {
|
switch arg := arguments[i].(type) {
|
||||||
case BinaryEncoder:
|
case Encoder:
|
||||||
wbuf.WriteInt16(BinaryFormatCode)
|
wbuf.WriteInt16(arg.FormatCode())
|
||||||
case TextEncoder:
|
|
||||||
wbuf.WriteInt16(TextFormatCode)
|
|
||||||
default:
|
default:
|
||||||
switch oid {
|
switch oid {
|
||||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid:
|
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid:
|
||||||
@@ -457,18 +490,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch arg := arguments[i].(type) {
|
switch arg := arguments[i].(type) {
|
||||||
case BinaryEncoder:
|
case Encoder:
|
||||||
err = arg.EncodeBinary(wbuf, oid)
|
err = arg.Encode(wbuf, oid)
|
||||||
case TextEncoder:
|
|
||||||
var s string
|
|
||||||
var status byte
|
|
||||||
s, status, err = arg.EncodeText()
|
|
||||||
if status == NullText {
|
|
||||||
wbuf.WriteInt32(-1)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
wbuf.WriteInt32(int32(len(s)))
|
|
||||||
wbuf.WriteBytes([]byte(s))
|
|
||||||
default:
|
default:
|
||||||
switch oid {
|
switch oid {
|
||||||
case BoolOid:
|
case BoolOid:
|
||||||
|
|||||||
@@ -321,6 +321,35 @@ func TestPrepare(t *testing.T) {
|
|||||||
if s != "hello" {
|
if s != "hello" {
|
||||||
t.Errorf("Prepared statement did not return expected value: %v", s)
|
t.Errorf("Prepared statement did not return expected value: %v", s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = conn.Deallocate("test")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("conn.Deallocate failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create another prepared statement to ensure Deallocate left the connection
|
||||||
|
// in a working state and that we can reuse the prepared statement name.
|
||||||
|
|
||||||
|
_, err = conn.Prepare("test", "select $1::integer")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unable to prepare statement: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var n int32
|
||||||
|
err = conn.QueryRow("test", int32(1)).Scan(&n)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Executing prepared statement failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n != 1 {
|
||||||
|
t.Errorf("Prepared statement did not return expected value: %v", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = conn.Deallocate("test")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("conn.Deallocate failed: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPrepareFailure(t *testing.T) {
|
func TestPrepareFailure(t *testing.T) {
|
||||||
|
|||||||
@@ -57,12 +57,19 @@ func (p *NullPoint) Scan(vr *pgx.ValueReader) error {
|
|||||||
return vr.Err()
|
return vr.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p NullPoint) EncodeText() (string, byte, error) {
|
func (p NullPoint) FormatCode() int16 { return pgx.BinaryFormatCode }
|
||||||
if p.Valid {
|
|
||||||
return fmt.Sprintf("point(%v,%v)", p.X, p.Y), pgx.SafeText, nil
|
func (p NullPoint) Encode(w *pgx.WriteBuf, oid pgx.Oid) error {
|
||||||
} else {
|
if !p.Valid {
|
||||||
return "", pgx.NullText, nil
|
w.WriteInt32(-1)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s := fmt.Sprintf("point(%v,%v)", p.X, p.Y)
|
||||||
|
w.WriteInt32(int32(len(s)))
|
||||||
|
w.WriteBytes([]byte(s))
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p NullPoint) String() string {
|
func (p NullPoint) String() string {
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ const (
|
|||||||
bindComplete = '2'
|
bindComplete = '2'
|
||||||
notificationResponse = 'A'
|
notificationResponse = 'A'
|
||||||
noData = 'n'
|
noData = 'n'
|
||||||
|
closeComplete = '3'
|
||||||
|
flush = 'H'
|
||||||
)
|
)
|
||||||
|
|
||||||
type startupMessage struct {
|
type startupMessage struct {
|
||||||
|
|||||||
@@ -317,42 +317,22 @@ func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) {
|
|||||||
c.rows = Rows{conn: c}
|
c.rows = Rows{conn: c}
|
||||||
rows := &c.rows
|
rows := &c.rows
|
||||||
|
|
||||||
if ps, present := c.preparedStatements[sql]; present {
|
ps, ok := c.preparedStatements[sql]
|
||||||
rows.fields = ps.FieldDescriptions
|
if !ok {
|
||||||
err := c.sendPreparedQuery(ps, args...)
|
var err error
|
||||||
|
ps, err = c.Prepare("", sql)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rows.abort(err)
|
rows.abort(err)
|
||||||
}
|
|
||||||
return rows, rows.err
|
|
||||||
}
|
|
||||||
|
|
||||||
err := c.sendSimpleQuery(sql, args...)
|
|
||||||
if err != nil {
|
|
||||||
rows.abort(err)
|
|
||||||
return rows, rows.err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simple queries don't know the field descriptions of the result.
|
|
||||||
// Read until that is known before returning
|
|
||||||
for {
|
|
||||||
t, r, err := c.rxMsg()
|
|
||||||
if err != nil {
|
|
||||||
rows.Fatal(err)
|
|
||||||
return rows, rows.err
|
return rows, rows.err
|
||||||
}
|
}
|
||||||
|
|
||||||
switch t {
|
|
||||||
case rowDescription:
|
|
||||||
rows.fields = rows.conn.rxRowDescription(r)
|
|
||||||
return rows, nil
|
|
||||||
default:
|
|
||||||
err = rows.conn.processContextFreeMsg(t, r)
|
|
||||||
if err != nil {
|
|
||||||
rows.Fatal(err)
|
|
||||||
return rows, rows.err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rows.fields = ps.FieldDescriptions
|
||||||
|
err := c.sendPreparedQuery(ps, args...)
|
||||||
|
if err != nil {
|
||||||
|
rows.abort(err)
|
||||||
|
}
|
||||||
|
return rows, rows.err
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryRow is a convenience wrapper over Query. Any error that occurs while
|
// QueryRow is a convenience wrapper over Query. Any error that occurs while
|
||||||
|
|||||||
+11
-7
@@ -331,12 +331,16 @@ func TestQueryPreparedEncodeError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure that an argument that implements TextEncoder, but not BinaryEncoder
|
// Ensure that an argument that implements Encoder works when the parameter type
|
||||||
// works when the parameter type is a core type.
|
// is a core type.
|
||||||
type coreTextEncoder struct{}
|
type coreEncoder struct{}
|
||||||
|
|
||||||
func (n *coreTextEncoder) EncodeText() (string, byte, error) {
|
func (n coreEncoder) FormatCode() int16 { return pgx.TextFormatCode }
|
||||||
return "42", pgx.SafeText, nil
|
|
||||||
|
func (n *coreEncoder) Encode(w *pgx.WriteBuf, oid pgx.Oid) error {
|
||||||
|
w.WriteInt32(int32(2))
|
||||||
|
w.WriteBytes([]byte("42"))
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryPreparedEncodeCoreTextFormatError(t *testing.T) {
|
func TestQueryPreparedEncodeCoreTextFormatError(t *testing.T) {
|
||||||
@@ -348,7 +352,7 @@ func TestQueryPreparedEncodeCoreTextFormatError(t *testing.T) {
|
|||||||
mustPrepare(t, conn, "testTranscode", "select $1::integer")
|
mustPrepare(t, conn, "testTranscode", "select $1::integer")
|
||||||
|
|
||||||
var n int32
|
var n int32
|
||||||
err := conn.QueryRow("testTranscode", &coreTextEncoder{}).Scan(&n)
|
err := conn.QueryRow("testTranscode", &coreEncoder{}).Scan(&n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unexpected conn.QueryRow error: %v", err)
|
t.Fatalf("Unexpected conn.QueryRow error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -473,7 +477,7 @@ func TestQueryRowUnpreparedErrors(t *testing.T) {
|
|||||||
scanArgs []interface{}
|
scanArgs []interface{}
|
||||||
err string
|
err string
|
||||||
}{
|
}{
|
||||||
{"select $1", []interface{}{"Jack"}, []interface{}{&actual.i16}, "Expected type oid 21 but received type oid 705"},
|
{"select $1", []interface{}{"Jack"}, []interface{}{&actual.i16}, "could not determine data type of parameter $1 (SQLSTATE 42P18)"},
|
||||||
{"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`},
|
{"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`},
|
||||||
{"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"},
|
{"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,9 +4,7 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"regexp"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
@@ -33,13 +31,6 @@ const (
|
|||||||
BinaryFormatCode = 1
|
BinaryFormatCode = 1
|
||||||
)
|
)
|
||||||
|
|
||||||
// EncodeText statuses
|
|
||||||
const (
|
|
||||||
NullText = iota
|
|
||||||
SafeText = iota
|
|
||||||
UnsafeText = iota
|
|
||||||
)
|
|
||||||
|
|
||||||
type SerializationError string
|
type SerializationError string
|
||||||
|
|
||||||
func (e SerializationError) Error() string {
|
func (e SerializationError) Error() string {
|
||||||
@@ -49,32 +40,28 @@ func (e SerializationError) Error() string {
|
|||||||
// Scanner is an interface used to decode values from the PostgreSQL server.
|
// Scanner is an interface used to decode values from the PostgreSQL server.
|
||||||
type Scanner interface {
|
type Scanner interface {
|
||||||
// Scan MUST check r.Type().DataType and r.Type().FormatCode before decoding.
|
// Scan MUST check r.Type().DataType and r.Type().FormatCode before decoding.
|
||||||
// It should not assume that it was called on the type of value.
|
// It should not assume that it was called on a data type or format that it
|
||||||
|
// understands.
|
||||||
Scan(r *ValueReader) error
|
Scan(r *ValueReader) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// TextEncoder is an interface used to encode values in text format for
|
// Encoder is an interface used to encode values for transmission to the
|
||||||
// transmission to the PostgreSQL server. It is used by unprepared
|
// PostgreSQL server.
|
||||||
// queries and for prepared queries when the type does not implement
|
type Encoder interface {
|
||||||
// BinaryEncoder
|
// Encode writes the value to w.
|
||||||
type TextEncoder interface {
|
|
||||||
// EncodeText returns the value encoded into a string. status must be
|
|
||||||
// NullText if the value is NULL, UnsafeText if the value should be quoted
|
|
||||||
// and escaped, or SafeText if the value should not be quoted.
|
|
||||||
EncodeText() (val string, status byte, err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BinaryEncoder is an interface used to encode values in binary format for
|
|
||||||
// transmission to the PostgreSQL server. It is used by prepared queries.
|
|
||||||
type BinaryEncoder interface {
|
|
||||||
// EncodeBinary writes the binary value to w.
|
|
||||||
//
|
//
|
||||||
// EncodeBinary MUST check oid to see if the parameter data type is
|
// If the value is NULL an int32(-1) should be written.
|
||||||
// compatible. If this is not done, the PostgreSQL server may detect the
|
//
|
||||||
// error if the expected data size or format of the encoded data does not
|
// Encode MUST check oid to see if the parameter data type is compatible. If
|
||||||
// match. But if the encoded data is a valid representation of the data type
|
// this is not done, the PostgreSQL server may detect the error if the
|
||||||
// PostgreSQL expects such as date and int4, incorrect data may be stored.
|
// expected data size or format of the encoded data does not match. But if
|
||||||
EncodeBinary(w *WriteBuf, oid Oid) error
|
// the encoded data is a valid representation of the data type PostgreSQL
|
||||||
|
// expects such as date and int4, incorrect data may be stored.
|
||||||
|
Encode(w *WriteBuf, oid Oid) error
|
||||||
|
|
||||||
|
// FormatCode returns the format that the encoder writes the value. It must be
|
||||||
|
// either pgx.TextFormatCode or pgx.BinaryFormatCode.
|
||||||
|
FormatCode() int16
|
||||||
}
|
}
|
||||||
|
|
||||||
// NullFloat32 represents an float4 that may be null.
|
// NullFloat32 represents an float4 that may be null.
|
||||||
@@ -102,17 +89,11 @@ func (n *NullFloat32) Scan(vr *ValueReader) error {
|
|||||||
return vr.Err()
|
return vr.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n NullFloat32) EncodeText() (string, byte, error) {
|
func (n NullFloat32) FormatCode() int16 { return BinaryFormatCode }
|
||||||
if n.Valid {
|
|
||||||
return strconv.FormatFloat(float64(n.Float32), 'f', -1, 32), SafeText, nil
|
|
||||||
} else {
|
|
||||||
return "", NullText, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n NullFloat32) EncodeBinary(w *WriteBuf, oid Oid) error {
|
func (n NullFloat32) Encode(w *WriteBuf, oid Oid) error {
|
||||||
if oid != Float4Oid {
|
if oid != Float4Oid {
|
||||||
return SerializationError(fmt.Sprintf("NullFloat32.EncodeBinary cannot encode into OID %d", oid))
|
return SerializationError(fmt.Sprintf("NullFloat32.Encode cannot encode into OID %d", oid))
|
||||||
}
|
}
|
||||||
|
|
||||||
if !n.Valid {
|
if !n.Valid {
|
||||||
@@ -148,15 +129,9 @@ func (n *NullFloat64) Scan(vr *ValueReader) error {
|
|||||||
return vr.Err()
|
return vr.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n NullFloat64) EncodeText() (string, byte, error) {
|
func (n NullFloat64) FormatCode() int16 { return BinaryFormatCode }
|
||||||
if n.Valid {
|
|
||||||
return strconv.FormatFloat(n.Float64, 'f', -1, 64), SafeText, nil
|
|
||||||
} else {
|
|
||||||
return "", NullText, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n NullFloat64) EncodeBinary(w *WriteBuf, oid Oid) error {
|
func (n NullFloat64) Encode(w *WriteBuf, oid Oid) error {
|
||||||
if oid != Float8Oid {
|
if oid != Float8Oid {
|
||||||
return SerializationError(fmt.Sprintf("NullFloat64.EncodeBinary cannot encode into OID %d", oid))
|
return SerializationError(fmt.Sprintf("NullFloat64.EncodeBinary cannot encode into OID %d", oid))
|
||||||
}
|
}
|
||||||
@@ -193,12 +168,15 @@ func (s *NullString) Scan(vr *ValueReader) error {
|
|||||||
return vr.Err()
|
return vr.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s NullString) EncodeText() (string, byte, error) {
|
func (n NullString) FormatCode() int16 { return TextFormatCode }
|
||||||
if s.Valid {
|
|
||||||
return s.String, UnsafeText, nil
|
func (s NullString) Encode(w *WriteBuf, oid Oid) error {
|
||||||
} else {
|
if !s.Valid {
|
||||||
return "", NullText, nil
|
w.WriteInt32(-1)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return encodeText(w, s.String)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NullInt16 represents an smallint that may be null.
|
// NullInt16 represents an smallint that may be null.
|
||||||
@@ -226,17 +204,11 @@ func (n *NullInt16) Scan(vr *ValueReader) error {
|
|||||||
return vr.Err()
|
return vr.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n NullInt16) EncodeText() (string, byte, error) {
|
func (n NullInt16) FormatCode() int16 { return BinaryFormatCode }
|
||||||
if n.Valid {
|
|
||||||
return strconv.FormatInt(int64(n.Int16), 10), SafeText, nil
|
|
||||||
} else {
|
|
||||||
return "", NullText, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n NullInt16) EncodeBinary(w *WriteBuf, oid Oid) error {
|
func (n NullInt16) Encode(w *WriteBuf, oid Oid) error {
|
||||||
if oid != Int2Oid {
|
if oid != Int2Oid {
|
||||||
return SerializationError(fmt.Sprintf("NullInt16.EncodeBinary cannot encode into OID %d", oid))
|
return SerializationError(fmt.Sprintf("NullInt16.Encode cannot encode into OID %d", oid))
|
||||||
}
|
}
|
||||||
|
|
||||||
if !n.Valid {
|
if !n.Valid {
|
||||||
@@ -272,17 +244,11 @@ func (n *NullInt32) Scan(vr *ValueReader) error {
|
|||||||
return vr.Err()
|
return vr.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n NullInt32) EncodeText() (string, byte, error) {
|
func (n NullInt32) FormatCode() int16 { return BinaryFormatCode }
|
||||||
if n.Valid {
|
|
||||||
return strconv.FormatInt(int64(n.Int32), 10), SafeText, nil
|
|
||||||
} else {
|
|
||||||
return "", NullText, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n NullInt32) EncodeBinary(w *WriteBuf, oid Oid) error {
|
func (n NullInt32) Encode(w *WriteBuf, oid Oid) error {
|
||||||
if oid != Int4Oid {
|
if oid != Int4Oid {
|
||||||
return SerializationError(fmt.Sprintf("NullInt32.EncodeBinary cannot encode into OID %d", oid))
|
return SerializationError(fmt.Sprintf("NullInt32.Encode cannot encode into OID %d", oid))
|
||||||
}
|
}
|
||||||
|
|
||||||
if !n.Valid {
|
if !n.Valid {
|
||||||
@@ -318,17 +284,11 @@ func (n *NullInt64) Scan(vr *ValueReader) error {
|
|||||||
return vr.Err()
|
return vr.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n NullInt64) EncodeText() (string, byte, error) {
|
func (n NullInt64) FormatCode() int16 { return BinaryFormatCode }
|
||||||
if n.Valid {
|
|
||||||
return strconv.FormatInt(int64(n.Int64), 10), SafeText, nil
|
|
||||||
} else {
|
|
||||||
return "", NullText, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n NullInt64) EncodeBinary(w *WriteBuf, oid Oid) error {
|
func (n NullInt64) Encode(w *WriteBuf, oid Oid) error {
|
||||||
if oid != Int8Oid {
|
if oid != Int8Oid {
|
||||||
return SerializationError(fmt.Sprintf("NullInt64.EncodeBinary cannot encode into OID %d", oid))
|
return SerializationError(fmt.Sprintf("NullInt64.Encode cannot encode into OID %d", oid))
|
||||||
}
|
}
|
||||||
|
|
||||||
if !n.Valid {
|
if !n.Valid {
|
||||||
@@ -364,17 +324,11 @@ func (n *NullBool) Scan(vr *ValueReader) error {
|
|||||||
return vr.Err()
|
return vr.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n NullBool) EncodeText() (string, byte, error) {
|
func (n NullBool) FormatCode() int16 { return BinaryFormatCode }
|
||||||
if n.Valid {
|
|
||||||
return strconv.FormatBool(n.Bool), SafeText, nil
|
|
||||||
} else {
|
|
||||||
return "", NullText, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n NullBool) EncodeBinary(w *WriteBuf, oid Oid) error {
|
func (n NullBool) Encode(w *WriteBuf, oid Oid) error {
|
||||||
if oid != BoolOid {
|
if oid != BoolOid {
|
||||||
return SerializationError(fmt.Sprintf("NullBool.EncodeBinary cannot encode into OID %d", oid))
|
return SerializationError(fmt.Sprintf("NullBool.Encode cannot encode into OID %d", oid))
|
||||||
}
|
}
|
||||||
|
|
||||||
if !n.Valid {
|
if !n.Valid {
|
||||||
@@ -412,17 +366,11 @@ func (n *NullTime) Scan(vr *ValueReader) error {
|
|||||||
return vr.Err()
|
return vr.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n NullTime) EncodeText() (string, byte, error) {
|
func (n NullTime) FormatCode() int16 { return BinaryFormatCode }
|
||||||
if n.Valid {
|
|
||||||
return n.Time.Format("2006-01-02 15:04:05.999999 -0700"), UnsafeText, nil
|
|
||||||
} else {
|
|
||||||
return "", NullText, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n NullTime) EncodeBinary(w *WriteBuf, oid Oid) error {
|
func (n NullTime) Encode(w *WriteBuf, oid Oid) error {
|
||||||
if oid != TimestampTzOid {
|
if oid != TimestampTzOid {
|
||||||
return SerializationError(fmt.Sprintf("NullTime.EncodeBinary cannot encode into OID %d", oid))
|
return SerializationError(fmt.Sprintf("NullTime.Encode cannot encode into OID %d", oid))
|
||||||
}
|
}
|
||||||
|
|
||||||
if !n.Valid {
|
if !n.Valid {
|
||||||
@@ -433,99 +381,6 @@ func (n NullTime) EncodeBinary(w *WriteBuf, oid Oid) error {
|
|||||||
return encodeTimestampTz(w, n.Time)
|
return encodeTimestampTz(w, n.Time)
|
||||||
}
|
}
|
||||||
|
|
||||||
var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`)
|
|
||||||
|
|
||||||
// QuoteString escapes and quotes a string making it safe for interpolation
|
|
||||||
// into an SQL string.
|
|
||||||
func QuoteString(input string) (output string) {
|
|
||||||
output = "'" + strings.Replace(input, "'", "''", -1) + "'"
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// QuoteIdentifier escapes and quotes an identifier making it safe for
|
|
||||||
// interpolation into an SQL string
|
|
||||||
func QuoteIdentifier(input string) (output string) {
|
|
||||||
output = `"` + strings.Replace(input, `"`, `""`, -1) + `"`
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// SanitizeSql substitutely args positionaly into sql. Placeholder values are
|
|
||||||
// $ prefixed integers like $1, $2, $3, etc. args are sanitized and quoted as
|
|
||||||
// appropriate.
|
|
||||||
func SanitizeSql(sql string, args ...interface{}) (output string, err error) {
|
|
||||||
replacer := func(match string) (replacement string) {
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
n, _ := strconv.ParseInt(match[1:], 10, 0)
|
|
||||||
if int(n-1) >= len(args) {
|
|
||||||
err = fmt.Errorf("Cannot interpolate %v, only %d arguments provided", match, len(args))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var s string
|
|
||||||
s, err = sanitizeArg(args[n-1])
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
output = literalPattern.ReplaceAllStringFunc(sql, replacer)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func sanitizeArg(arg interface{}) (string, error) {
|
|
||||||
switch arg := arg.(type) {
|
|
||||||
case string:
|
|
||||||
return QuoteString(arg), nil
|
|
||||||
case int:
|
|
||||||
return strconv.FormatInt(int64(arg), 10), nil
|
|
||||||
case int8:
|
|
||||||
return strconv.FormatInt(int64(arg), 10), nil
|
|
||||||
case int16:
|
|
||||||
return strconv.FormatInt(int64(arg), 10), nil
|
|
||||||
case int32:
|
|
||||||
return strconv.FormatInt(int64(arg), 10), nil
|
|
||||||
case int64:
|
|
||||||
return strconv.FormatInt(int64(arg), 10), nil
|
|
||||||
case time.Time:
|
|
||||||
return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700")), nil
|
|
||||||
case uint:
|
|
||||||
return strconv.FormatUint(uint64(arg), 10), nil
|
|
||||||
case uint8:
|
|
||||||
return strconv.FormatUint(uint64(arg), 10), nil
|
|
||||||
case uint16:
|
|
||||||
return strconv.FormatUint(uint64(arg), 10), nil
|
|
||||||
case uint32:
|
|
||||||
return strconv.FormatUint(uint64(arg), 10), nil
|
|
||||||
case uint64:
|
|
||||||
return strconv.FormatUint(uint64(arg), 10), nil
|
|
||||||
case float32:
|
|
||||||
return strconv.FormatFloat(float64(arg), 'f', -1, 32), nil
|
|
||||||
case float64:
|
|
||||||
return strconv.FormatFloat(arg, 'f', -1, 64), nil
|
|
||||||
case bool:
|
|
||||||
return strconv.FormatBool(arg), nil
|
|
||||||
case []byte:
|
|
||||||
return `E'\\x` + hex.EncodeToString(arg) + `'`, nil
|
|
||||||
case nil:
|
|
||||||
return "null", nil
|
|
||||||
case TextEncoder:
|
|
||||||
s, status, err := arg.EncodeText()
|
|
||||||
switch status {
|
|
||||||
case NullText:
|
|
||||||
return "null", err
|
|
||||||
case UnsafeText:
|
|
||||||
return QuoteString(s), err
|
|
||||||
case SafeText:
|
|
||||||
return s, err
|
|
||||||
default:
|
|
||||||
return "", SerializationError("Received invalid status from EncodeText")
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return "", SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeBool(vr *ValueReader) bool {
|
func decodeBool(vr *ValueReader) bool {
|
||||||
if vr.Len() == -1 {
|
if vr.Len() == -1 {
|
||||||
vr.Fatal(ProtocolError("Cannot decode null into bool"))
|
vr.Fatal(ProtocolError("Cannot decode null into bool"))
|
||||||
|
|||||||
@@ -8,77 +8,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestQuoteString(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
if pgx.QuoteString("test") != "'test'" {
|
|
||||||
t.Error("Failed to quote string")
|
|
||||||
}
|
|
||||||
|
|
||||||
if pgx.QuoteString("Jack's") != "'Jack''s'" {
|
|
||||||
t.Error("Failed to quote and escape string with embedded quote")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSanitizeSql(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
successTests := []struct {
|
|
||||||
sql string
|
|
||||||
args []interface{}
|
|
||||||
output string
|
|
||||||
}{
|
|
||||||
{"select $1", []interface{}{nil}, "select null"},
|
|
||||||
{"select $1", []interface{}{"Jack's"}, "select 'Jack''s'"},
|
|
||||||
{"select $1", []interface{}{int(42)}, "select 42"},
|
|
||||||
{"select $1", []interface{}{uint(42)}, "select 42"},
|
|
||||||
{"select $1", []interface{}{int8(42)}, "select 42"},
|
|
||||||
{"select $1", []interface{}{int16(42)}, "select 42"},
|
|
||||||
{"select $1", []interface{}{int32(42)}, "select 42"},
|
|
||||||
{"select $1", []interface{}{int64(42)}, "select 42"},
|
|
||||||
{"select $1", []interface{}{uint8(42)}, "select 42"},
|
|
||||||
{"select $1", []interface{}{uint16(42)}, "select 42"},
|
|
||||||
{"select $1", []interface{}{uint32(42)}, "select 42"},
|
|
||||||
{"select $1", []interface{}{uint64(42)}, "select 42"},
|
|
||||||
{"select $1", []interface{}{float32(1.23)}, "select 1.23"},
|
|
||||||
{"select $1", []interface{}{float64(1.23)}, "select 1.23"},
|
|
||||||
{"select $1", []interface{}{true}, "select true"},
|
|
||||||
{"select $1, $2, $3", []interface{}{"Jack's", 42, 1.23}, "select 'Jack''s', 42, 1.23"},
|
|
||||||
{"select $1", []interface{}{[]byte{0, 15, 255, 17}}, `select E'\\x000fff11'`},
|
|
||||||
{"select $1", []interface{}{&pgx.NullInt64{Int64: 0, Valid: false}}, "select null"},
|
|
||||||
{"select $1", []interface{}{&pgx.NullInt64{Int64: 1, Valid: true}}, "select 1"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tt := range successTests {
|
|
||||||
san, err := pgx.SanitizeSql(tt.sql, tt.args...)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("%d. Unexpected failure: %v (sql -> %v, args -> %v)", i, err, tt.sql, tt.args)
|
|
||||||
}
|
|
||||||
if san != tt.output {
|
|
||||||
t.Errorf("%d. Expected %v, got %v (sql -> %v, args -> %v)", i, tt.output, san, tt.sql, tt.args)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
errorTests := []struct {
|
|
||||||
sql string
|
|
||||||
args []interface{}
|
|
||||||
err string
|
|
||||||
}{
|
|
||||||
{"select $1", []interface{}{t}, "is not a core type and it does not implement TextEncoder"},
|
|
||||||
{"select $1, $2", []interface{}{}, "Cannot interpolate $1, only 0 arguments provided"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tt := range errorTests {
|
|
||||||
_, err := pgx.SanitizeSql(tt.sql, tt.args...)
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("%d. Unexpected success (sql -> %v, args -> %v)", i, tt.sql, tt.args, err)
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), tt.err) {
|
|
||||||
t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, args -> %v)", i, tt.err, err, tt.sql, tt.args)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDateTranscode(t *testing.T) {
|
func TestDateTranscode(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user