2
0

Don't panic!

This commit is contained in:
Jack Christensen
2013-07-20 13:07:30 -05:00
parent 0c3753e507
commit 36904168b2
12 changed files with 109 additions and 84 deletions
+22 -25
View File
@@ -1,7 +1,6 @@
package pgx_test package pgx_test
import ( import (
"fmt"
"github.com/JackC/pgx" "github.com/JackC/pgx"
"math/rand" "math/rand"
"testing" "testing"
@@ -22,7 +21,7 @@ func createNarrowTestData(b *testing.B, conn *pgx.Connection) {
return return
} }
if _, err := conn.Execute(` mustExecute(b, conn, `
drop table if exists narrow; drop table if exists narrow;
create table narrow( create table narrow(
@@ -38,9 +37,7 @@ func createNarrowTestData(b *testing.B, conn *pgx.Connection) {
from generate_series(1, 10000); from generate_series(1, 10000);
analyze narrow; analyze narrow;
`); err != nil { `)
panic(fmt.Sprintf("Unable to create narrow test data: %v", err))
}
mustPrepare(b, conn, "getNarrowById", "select * from narrow where id=$1") mustPrepare(b, conn, "getNarrowById", "select * from narrow where id=$1")
mustPrepare(b, conn, "getMultipleNarrowById", "select * from narrow where id between $1 and $2") mustPrepare(b, conn, "getMultipleNarrowById", "select * from narrow where id between $1 and $2")
@@ -64,7 +61,7 @@ func restoreBinaryEncoders(encoders map[pgx.Oid]func(*pgx.MessageReader, int32)
} }
func BenchmarkSelectRowSimpleNarrow(b *testing.B) { func BenchmarkSelectRowSimpleNarrow(b *testing.B) {
conn := getSharedConnection() conn := getSharedConnection(b)
createNarrowTestData(b, conn) createNarrowTestData(b, conn)
// Get random ids outside of timing // Get random ids outside of timing
@@ -80,7 +77,7 @@ func BenchmarkSelectRowSimpleNarrow(b *testing.B) {
} }
func BenchmarkSelectRowPreparedNarrow(b *testing.B) { func BenchmarkSelectRowPreparedNarrow(b *testing.B) {
conn := getSharedConnection() conn := getSharedConnection(b)
createNarrowTestData(b, conn) createNarrowTestData(b, conn)
// Get random ids outside of timing // Get random ids outside of timing
@@ -96,7 +93,7 @@ func BenchmarkSelectRowPreparedNarrow(b *testing.B) {
} }
func BenchmarkSelectRowsSimpleNarrow(b *testing.B) { func BenchmarkSelectRowsSimpleNarrow(b *testing.B) {
conn := getSharedConnection() conn := getSharedConnection(b)
createNarrowTestData(b, conn) createNarrowTestData(b, conn)
// Get random ids outside of timing // Get random ids outside of timing
@@ -112,7 +109,7 @@ func BenchmarkSelectRowsSimpleNarrow(b *testing.B) {
} }
func BenchmarkSelectRowsPreparedNarrow(b *testing.B) { func BenchmarkSelectRowsPreparedNarrow(b *testing.B) {
conn := getSharedConnection() conn := getSharedConnection(b)
createNarrowTestData(b, conn) createNarrowTestData(b, conn)
// Get random ids outside of timing // Get random ids outside of timing
@@ -199,7 +196,7 @@ func createJoinsTestData(b *testing.B, conn *pgx.Connection) {
} }
func BenchmarkSelectRowsSimpleJoins(b *testing.B) { func BenchmarkSelectRowsSimpleJoins(b *testing.B) {
conn := getSharedConnection() conn := getSharedConnection(b)
createJoinsTestData(b, conn) createJoinsTestData(b, conn)
sql := ` sql := `
@@ -219,7 +216,7 @@ func BenchmarkSelectRowsSimpleJoins(b *testing.B) {
} }
func BenchmarkSelectRowsPreparedJoins(b *testing.B) { func BenchmarkSelectRowsPreparedJoins(b *testing.B) {
conn := getSharedConnection() conn := getSharedConnection(b)
createJoinsTestData(b, conn) createJoinsTestData(b, conn)
b.ResetTimer() b.ResetTimer()
@@ -254,7 +251,7 @@ func createInt2TextVsBinaryTestData(b *testing.B, conn *pgx.Connection) {
} }
func BenchmarkInt2Text(b *testing.B) { func BenchmarkInt2Text(b *testing.B) {
conn := getSharedConnection() conn := getSharedConnection(b)
createInt2TextVsBinaryTestData(b, conn) createInt2TextVsBinaryTestData(b, conn)
encoders := removeBinaryEncoders() encoders := removeBinaryEncoders()
@@ -270,7 +267,7 @@ func BenchmarkInt2Text(b *testing.B) {
} }
func BenchmarkInt2Binary(b *testing.B) { func BenchmarkInt2Binary(b *testing.B) {
conn := getSharedConnection() conn := getSharedConnection(b)
createInt2TextVsBinaryTestData(b, conn) createInt2TextVsBinaryTestData(b, conn)
mustPrepare(b, conn, "selectInt16", "select * from t") mustPrepare(b, conn, "selectInt16", "select * from t")
defer func() { conn.Deallocate("selectInt16") }() defer func() { conn.Deallocate("selectInt16") }()
@@ -307,7 +304,7 @@ func createInt4TextVsBinaryTestData(b *testing.B, conn *pgx.Connection) {
} }
func BenchmarkInt4Text(b *testing.B) { func BenchmarkInt4Text(b *testing.B) {
conn := getSharedConnection() conn := getSharedConnection(b)
createInt4TextVsBinaryTestData(b, conn) createInt4TextVsBinaryTestData(b, conn)
encoders := removeBinaryEncoders() encoders := removeBinaryEncoders()
@@ -323,7 +320,7 @@ func BenchmarkInt4Text(b *testing.B) {
} }
func BenchmarkInt4Binary(b *testing.B) { func BenchmarkInt4Binary(b *testing.B) {
conn := getSharedConnection() conn := getSharedConnection(b)
createInt4TextVsBinaryTestData(b, conn) createInt4TextVsBinaryTestData(b, conn)
mustPrepare(b, conn, "selectInt32", "select * from t") mustPrepare(b, conn, "selectInt32", "select * from t")
defer func() { conn.Deallocate("selectInt32") }() defer func() { conn.Deallocate("selectInt32") }()
@@ -360,7 +357,7 @@ func createInt8TextVsBinaryTestData(b *testing.B, conn *pgx.Connection) {
} }
func BenchmarkInt8Text(b *testing.B) { func BenchmarkInt8Text(b *testing.B) {
conn := getSharedConnection() conn := getSharedConnection(b)
createInt8TextVsBinaryTestData(b, conn) createInt8TextVsBinaryTestData(b, conn)
encoders := removeBinaryEncoders() encoders := removeBinaryEncoders()
@@ -376,7 +373,7 @@ func BenchmarkInt8Text(b *testing.B) {
} }
func BenchmarkInt8Binary(b *testing.B) { func BenchmarkInt8Binary(b *testing.B) {
conn := getSharedConnection() conn := getSharedConnection(b)
createInt8TextVsBinaryTestData(b, conn) createInt8TextVsBinaryTestData(b, conn)
mustPrepare(b, conn, "selectInt64", "select * from t") mustPrepare(b, conn, "selectInt64", "select * from t")
defer func() { conn.Deallocate("selectInt64") }() defer func() { conn.Deallocate("selectInt64") }()
@@ -413,7 +410,7 @@ func createFloat4TextVsBinaryTestData(b *testing.B, conn *pgx.Connection) {
} }
func BenchmarkFloat4Text(b *testing.B) { func BenchmarkFloat4Text(b *testing.B) {
conn := getSharedConnection() conn := getSharedConnection(b)
createFloat4TextVsBinaryTestData(b, conn) createFloat4TextVsBinaryTestData(b, conn)
encoders := removeBinaryEncoders() encoders := removeBinaryEncoders()
@@ -429,7 +426,7 @@ func BenchmarkFloat4Text(b *testing.B) {
} }
func BenchmarkFloat4Binary(b *testing.B) { func BenchmarkFloat4Binary(b *testing.B) {
conn := getSharedConnection() conn := getSharedConnection(b)
createFloat4TextVsBinaryTestData(b, conn) createFloat4TextVsBinaryTestData(b, conn)
mustPrepare(b, conn, "selectFloat32", "select * from t") mustPrepare(b, conn, "selectFloat32", "select * from t")
defer func() { conn.Deallocate("selectFloat32") }() defer func() { conn.Deallocate("selectFloat32") }()
@@ -466,7 +463,7 @@ func createFloat8TextVsBinaryTestData(b *testing.B, conn *pgx.Connection) {
} }
func BenchmarkFloat8Text(b *testing.B) { func BenchmarkFloat8Text(b *testing.B) {
conn := getSharedConnection() conn := getSharedConnection(b)
createFloat8TextVsBinaryTestData(b, conn) createFloat8TextVsBinaryTestData(b, conn)
encoders := removeBinaryEncoders() encoders := removeBinaryEncoders()
@@ -482,7 +479,7 @@ func BenchmarkFloat8Text(b *testing.B) {
} }
func BenchmarkFloat8Binary(b *testing.B) { func BenchmarkFloat8Binary(b *testing.B) {
conn := getSharedConnection() conn := getSharedConnection(b)
createFloat8TextVsBinaryTestData(b, conn) createFloat8TextVsBinaryTestData(b, conn)
mustPrepare(b, conn, "selectFloat32", "select * from t") mustPrepare(b, conn, "selectFloat32", "select * from t")
defer func() { conn.Deallocate("selectFloat32") }() defer func() { conn.Deallocate("selectFloat32") }()
@@ -519,7 +516,7 @@ func createBoolTextVsBinaryTestData(b *testing.B, conn *pgx.Connection) {
} }
func BenchmarkBoolText(b *testing.B) { func BenchmarkBoolText(b *testing.B) {
conn := getSharedConnection() conn := getSharedConnection(b)
createBoolTextVsBinaryTestData(b, conn) createBoolTextVsBinaryTestData(b, conn)
encoders := removeBinaryEncoders() encoders := removeBinaryEncoders()
@@ -535,7 +532,7 @@ func BenchmarkBoolText(b *testing.B) {
} }
func BenchmarkBoolBinary(b *testing.B) { func BenchmarkBoolBinary(b *testing.B) {
conn := getSharedConnection() conn := getSharedConnection(b)
createBoolTextVsBinaryTestData(b, conn) createBoolTextVsBinaryTestData(b, conn)
mustPrepare(b, conn, "selectBool", "select * from t") mustPrepare(b, conn, "selectBool", "select * from t")
defer func() { conn.Deallocate("selectBool") }() defer func() { conn.Deallocate("selectBool") }()
@@ -576,7 +573,7 @@ func createTimestampTzTextVsBinaryTestData(b *testing.B, conn *pgx.Connection) {
} }
func BenchmarkTimestampTzText(b *testing.B) { func BenchmarkTimestampTzText(b *testing.B) {
conn := getSharedConnection() conn := getSharedConnection(b)
createTimestampTzTextVsBinaryTestData(b, conn) createTimestampTzTextVsBinaryTestData(b, conn)
encoders := removeBinaryEncoders() encoders := removeBinaryEncoders()
@@ -592,7 +589,7 @@ func BenchmarkTimestampTzText(b *testing.B) {
} }
func BenchmarkTimestampTzBinary(b *testing.B) { func BenchmarkTimestampTzBinary(b *testing.B) {
conn := getSharedConnection() conn := getSharedConnection(b)
createTimestampTzTextVsBinaryTestData(b, conn) createTimestampTzTextVsBinaryTestData(b, conn)
mustPrepare(b, conn, "selectTimestampTz", "select * from t") mustPrepare(b, conn, "selectTimestampTz", "select * from t")
defer func() { conn.Deallocate("selectTimestampTz") }() defer func() { conn.Deallocate("selectTimestampTz") }()
+15 -2
View File
@@ -73,6 +73,12 @@ func (e UnexpectedColumnCountError) Error() string {
return fmt.Sprintf("Expected result to have %d column(s), instead it has %d", e.ExpectedCount, e.ActualCount) return fmt.Sprintf("Expected result to have %d column(s), instead it has %d", e.ExpectedCount, e.ActualCount)
} }
type ProtocolError string
func (e ProtocolError) Error() string {
return string(e)
}
// sharedBufferSize is the default number of bytes of work buffer per connection // sharedBufferSize is the default number of bytes of work buffer per connection
const sharedBufferSize = 1024 const sharedBufferSize = 1024
@@ -171,7 +177,11 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error
fields = c.rxRowDescription(r) fields = c.rxRowDescription(r)
case dataRow: case dataRow:
if err == nil { if err == nil {
err = onDataRow(newDataRowReader(r, fields)) var drr *DataRowReader
drr, err = newDataRowReader(r, fields)
if err == nil {
err = onDataRow(drr)
}
} }
case commandComplete: case commandComplete:
case bindComplete: case bindComplete:
@@ -393,7 +403,10 @@ func (c *Connection) sendQuery(sql string, arguments ...interface{}) (err error)
func (c *Connection) sendSimpleQuery(sql string, arguments ...interface{}) (err error) { func (c *Connection) sendSimpleQuery(sql string, arguments ...interface{}) (err error) {
if len(arguments) > 0 { if len(arguments) > 0 {
sql = c.SanitizeSql(sql, arguments...) sql, err = c.SanitizeSql(sql, arguments...)
if err != nil {
return
}
} }
buf := c.getBuf() buf := c.getBuf()
+4 -4
View File
@@ -6,10 +6,10 @@ import (
"testing" "testing"
) )
func createConnectionPool(maxConnections int) *pgx.ConnectionPool { func createConnectionPool(t *testing.T, maxConnections int) *pgx.ConnectionPool {
pool, err := pgx.NewConnectionPool(*defaultConnectionParameters, maxConnections) pool, err := pgx.NewConnectionPool(*defaultConnectionParameters, maxConnections)
if err != nil { if err != nil {
panic("Unable to create connection pool") t.Fatalf("Unable to create connection pool: %v", err)
} }
return pool return pool
} }
@@ -30,7 +30,7 @@ func TestPoolAcquireAndReleaseCycle(t *testing.T) {
maxConnections := 2 maxConnections := 2
incrementCount := int32(100) incrementCount := int32(100)
completeSync := make(chan int) completeSync := make(chan int)
pool := createConnectionPool(maxConnections) pool := createConnectionPool(t, maxConnections)
defer pool.Close() defer pool.Close()
acquireAll := func() (connections []*pgx.Connection) { acquireAll := func() (connections []*pgx.Connection) {
@@ -99,7 +99,7 @@ func TestPoolAcquireAndReleaseCycle(t *testing.T) {
} }
func TestPoolReleaseWithTransactions(t *testing.T) { func TestPoolReleaseWithTransactions(t *testing.T) {
pool := createConnectionPool(1) pool := createConnectionPool(t, 1)
defer pool.Close() defer pool.Close()
var err error var err error
+16 -10
View File
@@ -124,7 +124,7 @@ func TestConnectWithMD5Password(t *testing.T) {
} }
func TestExecute(t *testing.T) { func TestExecute(t *testing.T) {
conn := getSharedConnection() conn := getSharedConnection(t)
if results := mustExecute(t, conn, "create temporary table foo(id integer primary key);"); results != "CREATE TABLE" { if results := mustExecute(t, conn, "create temporary table foo(id integer primary key);"); results != "CREATE TABLE" {
t.Error("Unexpected results from Execute") t.Error("Unexpected results from Execute")
@@ -167,7 +167,7 @@ func TestExecuteFailure(t *testing.T) {
} }
func TestSelectFunc(t *testing.T) { func TestSelectFunc(t *testing.T) {
conn := getSharedConnection() conn := getSharedConnection(t)
var sum, rowCount int32 var sum, rowCount int32
onDataRow := func(r *pgx.DataRowReader) error { onDataRow := func(r *pgx.DataRowReader) error {
@@ -206,14 +206,18 @@ func TestSelectFuncFailure(t *testing.T) {
} }
func Example_connectionSelectFunc() { func Example_connectionSelectFunc() {
conn := getSharedConnection() conn, err := pgx.Connect(*defaultConnectionParameters)
if err != nil {
fmt.Printf("Unable to establish connection: %v", err)
return
}
onDataRow := func(r *pgx.DataRowReader) error { onDataRow := func(r *pgx.DataRowReader) error {
fmt.Println(r.ReadValue()) fmt.Println(r.ReadValue())
return nil return nil
} }
err := conn.SelectFunc("select generate_series(1,$1)", onDataRow, 5) err = conn.SelectFunc("select generate_series(1,$1)", onDataRow, 5)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} }
@@ -226,7 +230,7 @@ func Example_connectionSelectFunc() {
} }
func TestSelectRows(t *testing.T) { func TestSelectRows(t *testing.T) {
conn := getSharedConnection() conn := getSharedConnection(t)
rows := mustSelectRows(t, conn, "select $1 as name, null as position", "Jack") rows := mustSelectRows(t, conn, "select $1 as name, null as position", "Jack")
@@ -248,7 +252,7 @@ func TestSelectRows(t *testing.T) {
} }
func TestSelectRow(t *testing.T) { func TestSelectRow(t *testing.T) {
conn := getSharedConnection() conn := getSharedConnection(t)
row := mustSelectRow(t, conn, "select $1 as name, null as position", "Jack") row := mustSelectRow(t, conn, "select $1 as name, null as position", "Jack")
if row["name"] != "Jack" { if row["name"] != "Jack" {
@@ -275,7 +279,7 @@ func TestSelectRow(t *testing.T) {
} }
func TestConnectionSelectValue(t *testing.T) { func TestConnectionSelectValue(t *testing.T) {
conn := getSharedConnection() conn := getSharedConnection(t)
test := func(sql string, expected interface{}, arguments ...interface{}) { test := func(sql string, expected interface{}, arguments ...interface{}) {
v, err := conn.SelectValue(sql, arguments...) v, err := conn.SelectValue(sql, arguments...)
@@ -315,7 +319,7 @@ func TestConnectionSelectValue(t *testing.T) {
} }
func TestSelectValues(t *testing.T) { func TestSelectValues(t *testing.T) {
conn := getSharedConnection() conn := getSharedConnection(t)
test := func(sql string, expected []interface{}, arguments ...interface{}) { test := func(sql string, expected []interface{}, arguments ...interface{}) {
values, err := conn.SelectValues(sql, arguments...) values, err := conn.SelectValues(sql, arguments...)
@@ -406,7 +410,9 @@ func TestPrepare(t *testing.T) {
bytea[2] = 255 // 0xFF bytea[2] = 255 // 0xFF
bytea[3] = 17 // 0x11 bytea[3] = 17 // 0x11
if conn.SanitizeSql("select $1", bytea) != `select E'\\x000fff11'` { if sql, err := conn.SanitizeSql("select $1", bytea); err != nil {
t.Errorf("Error sanitizing []byte: %v", err)
} else if sql != `select E'\\x000fff11'` {
t.Error("Failed to sanitize []byte") t.Error("Failed to sanitize []byte")
} }
var result interface{} var result interface{}
@@ -579,7 +585,7 @@ func TestListenNotify(t *testing.T) {
t.Fatalf("Unable to start listening: %v", err) t.Fatalf("Unable to start listening: %v", err)
} }
notifier := getSharedConnection() notifier := getSharedConnection(t)
mustExecute(t, notifier, "notify chat") mustExecute(t, notifier, "notify chat")
// when notification is waiting on the socket to be read // when notification is waiting on the socket to be read
+7 -3
View File
@@ -1,5 +1,9 @@
package pgx package pgx
import (
"fmt"
)
// DataRowReader is used by SelectFunc to process incoming rows. // DataRowReader is used by SelectFunc to process incoming rows.
type DataRowReader struct { type DataRowReader struct {
mr *MessageReader mr *MessageReader
@@ -7,14 +11,14 @@ type DataRowReader struct {
currentFieldIdx int currentFieldIdx int
} }
func newDataRowReader(mr *MessageReader, fields []FieldDescription) (r *DataRowReader) { func newDataRowReader(mr *MessageReader, fields []FieldDescription) (r *DataRowReader, err error) {
r = new(DataRowReader) r = new(DataRowReader)
r.mr = mr r.mr = mr
r.fields = fields r.fields = fields
fieldCount := int(mr.ReadInt16()) fieldCount := int(mr.ReadInt16())
if fieldCount != len(fields) { if fieldCount != len(fields) {
panic("Row description field count and data row field count do not match") return nil, ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(fields), fieldCount))
} }
return return
@@ -34,7 +38,7 @@ func (r *DataRowReader) ReadValue() interface{} {
case 1: case 1:
return vt.DecodeBinary(r.mr, size) return vt.DecodeBinary(r.mr, size)
default: default:
panic("Unknown format") return ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fieldDescription.FormatCode))
} }
} else { } else {
return r.mr.ReadString(size) return r.mr.ReadString(size)
+1 -1
View File
@@ -6,7 +6,7 @@ import (
) )
func TestDataRowReaderReadValue(t *testing.T) { func TestDataRowReaderReadValue(t *testing.T) {
conn := getSharedConnection() conn := getSharedConnection(t)
test := func(sql string, expected interface{}) { test := func(sql string, expected interface{}) {
var v interface{} var v interface{}
+8 -4
View File
@@ -23,7 +23,11 @@ func Example_customValueTranscoder() {
DecodeText: decodePointFromText, DecodeText: decodePointFromText,
EncodeTo: encodePoint} EncodeTo: encodePoint}
conn := getSharedConnection() conn, err := pgx.Connect(*defaultConnectionParameters)
if err != nil {
fmt.Printf("Unable to establish connection: %v", err)
return
}
v, _ := conn.SelectValue("select point(1.5,2.5)") v, _ := conn.SelectValue("select point(1.5,2.5)")
fmt.Println(v) fmt.Println(v)
@@ -35,18 +39,18 @@ func decodePointFromText(mr *pgx.MessageReader, size int32) interface{} {
s := mr.ReadString(size) s := mr.ReadString(size)
match := pointRegexp.FindStringSubmatch(s) match := pointRegexp.FindStringSubmatch(s)
if match == nil { if match == nil {
panic(fmt.Sprintf("Received invalid point: %v", s)) return pgx.ProtocolError(fmt.Sprintf("Received invalid point: %v", s))
} }
var err error var err error
var p Point var p Point
p.x, err = strconv.ParseFloat(match[1], 64) p.x, err = strconv.ParseFloat(match[1], 64)
if err != nil { if err != nil {
panic(fmt.Sprintf("Received invalid point: %v", s)) return pgx.ProtocolError(fmt.Sprintf("Received invalid point: %v", s))
} }
p.y, err = strconv.ParseFloat(match[2], 64) p.y, err = strconv.ParseFloat(match[2], 64)
if err != nil { if err != nil {
panic(fmt.Sprintf("Received invalid point: %v", s)) return pgx.ProtocolError(fmt.Sprintf("Received invalid point: %v", s))
} }
return p return p
} }
+2 -2
View File
@@ -10,12 +10,12 @@ type test interface {
var sharedConnection *pgx.Connection var sharedConnection *pgx.Connection
func getSharedConnection() (c *pgx.Connection) { func getSharedConnection(t test) (c *pgx.Connection) {
if sharedConnection == nil { if sharedConnection == nil {
var err error var err error
sharedConnection, err = pgx.Connect(*defaultConnectionParameters) sharedConnection, err = pgx.Connect(*defaultConnectionParameters)
if err != nil { if err != nil {
panic("Unable to establish connection") t.Fatalf("Unable to establish connection: %v", err)
} }
} }
+4 -3
View File
@@ -2,7 +2,7 @@ package pgx
import ( import (
"encoding/hex" "encoding/hex"
"reflect" "fmt"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
@@ -28,7 +28,7 @@ func (c *Connection) QuoteIdentifier(input string) (output string) {
// SanitizeSql substitutely args positionaly into sql. Placeholder values are // SanitizeSql substitutely args positionaly into sql. Placeholder values are
// $ prefixed integers like $1, $2, $3, etc. args are sanitized and quoted as // $ prefixed integers like $1, $2, $3, etc. args are sanitized and quoted as
// appropriate. // appropriate.
func (c *Connection) SanitizeSql(sql string, args ...interface{}) (output string) { func (c *Connection) SanitizeSql(sql string, args ...interface{}) (output string, err error) {
replacer := func(match string) (replacement string) { replacer := func(match string) (replacement string) {
n, _ := strconv.ParseInt(match[1:], 10, 0) n, _ := strconv.ParseInt(match[1:], 10, 0)
switch arg := args[n-1].(type) { switch arg := args[n-1].(type) {
@@ -63,7 +63,8 @@ func (c *Connection) SanitizeSql(sql string, args ...interface{}) (output string
case []byte: case []byte:
return `E'\\x` + hex.EncodeToString(arg) + `'` return `E'\\x` + hex.EncodeToString(arg) + `'`
default: default:
panic("Unable to sanitize type: " + reflect.TypeOf(arg).String()) err = fmt.Errorf("Unable to sanitize type: %T", arg)
return ""
} }
} }
+12 -12
View File
@@ -5,7 +5,7 @@ import (
) )
func TestQuoteString(t *testing.T) { func TestQuoteString(t *testing.T) {
conn := getSharedConnection() conn := getSharedConnection(t)
if conn.QuoteString("test") != "'test'" { if conn.QuoteString("test") != "'test'" {
t.Error("Failed to quote string") t.Error("Failed to quote string")
@@ -17,22 +17,22 @@ func TestQuoteString(t *testing.T) {
} }
func TestSanitizeSql(t *testing.T) { func TestSanitizeSql(t *testing.T) {
conn := getSharedConnection() conn := getSharedConnection(t)
if conn.SanitizeSql("select $1", "Jack's") != "select 'Jack''s'" { if san, err := conn.SanitizeSql("select $1", "Jack's"); err != nil || san != "select 'Jack''s'" {
t.Error("Failed to sanitize string") t.Errorf("Failed to sanitize string: %v - %v", san, err)
} }
if conn.SanitizeSql("select $1", 42) != "select 42" { if san, err := conn.SanitizeSql("select $1", 42); err != nil || san != "select 42" {
t.Error("Failed to pass through integer") t.Errorf("Failed to pass through integer: %v - %v", san, err)
} }
if conn.SanitizeSql("select $1", 1.23) != "select 1.23" { if san, err := conn.SanitizeSql("select $1", 1.23); err != nil || san != "select 1.23" {
t.Error("Failed to pass through float") t.Errorf("Failed to pass through float: %v - %v", san, err)
} }
if conn.SanitizeSql("select $1, $2, $3", "Jack's", 42, 1.23) != "select 'Jack''s', 42, 1.23" { if san, err := conn.SanitizeSql("select $1, $2, $3", "Jack's", 42, 1.23); err != nil || san != "select 'Jack''s', 42, 1.23" {
t.Error("Failed to sanitize multiple params") t.Errorf("Failed to sanitize multiple params: %v - %v", san, err)
} }
bytea := make([]byte, 4) bytea := make([]byte, 4)
@@ -41,7 +41,7 @@ func TestSanitizeSql(t *testing.T) {
bytea[2] = 255 // 0xFF bytea[2] = 255 // 0xFF
bytea[3] = 17 // 0x11 bytea[3] = 17 // 0x11
if conn.SanitizeSql("select $1", bytea) != `select E'\\x000fff11'` { if san, err := conn.SanitizeSql("select $1", bytea); err != nil || san != `select E'\\x000fff11'` {
t.Error("Failed to sanitize []byte") t.Errorf("Failed to sanitize []byte: %v - %v", san, err)
} }
} }
+16 -16
View File
@@ -112,13 +112,13 @@ func decodeBoolFromText(mr *MessageReader, size int32) interface{} {
case "f": case "f":
return false return false
default: default:
panic(fmt.Sprintf("Received invalid bool: %v", s)) return ProtocolError(fmt.Sprintf("Received invalid bool: %v", s))
} }
} }
func decodeBoolFromBinary(mr *MessageReader, size int32) interface{} { func decodeBoolFromBinary(mr *MessageReader, size int32) interface{} {
if size != 1 { if size != 1 {
panic("Received an invalid size for an bool") return ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", size))
} }
b := mr.ReadByte() b := mr.ReadByte()
return b != 0 return b != 0
@@ -138,14 +138,14 @@ func decodeInt8FromText(mr *MessageReader, size int32) interface{} {
s := mr.ReadString(size) s := mr.ReadString(size)
n, err := strconv.ParseInt(s, 10, 64) n, err := strconv.ParseInt(s, 10, 64)
if err != nil { if err != nil {
panic(fmt.Sprintf("Received invalid int8: %v", s)) return ProtocolError(fmt.Sprintf("Received invalid int8: %v", s))
} }
return n return n
} }
func decodeInt8FromBinary(mr *MessageReader, size int32) interface{} { func decodeInt8FromBinary(mr *MessageReader, size int32) interface{} {
if size != 8 { if size != 8 {
panic("Received an invalid size for an int8") return ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", size))
} }
return mr.ReadInt64() return mr.ReadInt64()
} }
@@ -160,14 +160,14 @@ func decodeInt2FromText(mr *MessageReader, size int32) interface{} {
s := mr.ReadString(size) s := mr.ReadString(size)
n, err := strconv.ParseInt(s, 10, 16) n, err := strconv.ParseInt(s, 10, 16)
if err != nil { if err != nil {
panic(fmt.Sprintf("Received invalid int2: %v", s)) return ProtocolError(fmt.Sprintf("Received invalid int2: %v", s))
} }
return int16(n) return int16(n)
} }
func decodeInt2FromBinary(mr *MessageReader, size int32) interface{} { func decodeInt2FromBinary(mr *MessageReader, size int32) interface{} {
if size != 2 { if size != 2 {
panic("Received an invalid size for an int8") return ProtocolError(fmt.Sprintf("Received an invalid size for an int2: %d", size))
} }
return mr.ReadInt16() return mr.ReadInt16()
} }
@@ -182,14 +182,14 @@ func decodeInt4FromText(mr *MessageReader, size int32) interface{} {
s := mr.ReadString(size) s := mr.ReadString(size)
n, err := strconv.ParseInt(s, 10, 32) n, err := strconv.ParseInt(s, 10, 32)
if err != nil { if err != nil {
panic(fmt.Sprintf("Received invalid int4: %v", s)) return ProtocolError(fmt.Sprintf("Received invalid int4: %v", s))
} }
return int32(n) return int32(n)
} }
func decodeInt4FromBinary(mr *MessageReader, size int32) interface{} { func decodeInt4FromBinary(mr *MessageReader, size int32) interface{} {
if size != 4 { if size != 4 {
panic("Received an invalid size for an int4") return ProtocolError(fmt.Sprintf("Received an invalid size for an int4: %d", size))
} }
return mr.ReadInt32() return mr.ReadInt32()
} }
@@ -204,14 +204,14 @@ func decodeFloat4FromText(mr *MessageReader, size int32) interface{} {
s := mr.ReadString(size) s := mr.ReadString(size)
n, err := strconv.ParseFloat(s, 32) n, err := strconv.ParseFloat(s, 32)
if err != nil { if err != nil {
panic(fmt.Sprintf("Received invalid float4: %v", s)) return ProtocolError(fmt.Sprintf("Received invalid float4: %v", s))
} }
return float32(n) return float32(n)
} }
func decodeFloat4FromBinary(mr *MessageReader, size int32) interface{} { func decodeFloat4FromBinary(mr *MessageReader, size int32) interface{} {
if size != 4 { if size != 4 {
panic("Received an invalid size for an float4") return ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", size))
} }
i := mr.ReadInt32() i := mr.ReadInt32()
@@ -229,14 +229,14 @@ func decodeFloat8FromText(mr *MessageReader, size int32) interface{} {
s := mr.ReadString(size) s := mr.ReadString(size)
v, err := strconv.ParseFloat(s, 64) v, err := strconv.ParseFloat(s, 64)
if err != nil { if err != nil {
panic(fmt.Sprintf("Received invalid float8: %v", s)) return ProtocolError(fmt.Sprintf("Received invalid float8: %v", s))
} }
return v return v
} }
func decodeFloat8FromBinary(mr *MessageReader, size int32) interface{} { func decodeFloat8FromBinary(mr *MessageReader, size int32) interface{} {
if size != 8 { if size != 8 {
panic("Received an invalid size for an float4") return ProtocolError(fmt.Sprintf("Received an invalid size for an float8: %d", size))
} }
i := mr.ReadInt64() i := mr.ReadInt64()
@@ -264,7 +264,7 @@ func decodeByteaFromText(mr *MessageReader, size int32) interface{} {
s := mr.ReadString(size) s := mr.ReadString(size)
b, err := hex.DecodeString(s[2:]) b, err := hex.DecodeString(s[2:])
if err != nil { if err != nil {
panic("Can't decode byte array") return ProtocolError(fmt.Sprintf("Can't decode byte array: %v - %v", err, s))
} }
return b return b
} }
@@ -279,7 +279,7 @@ func decodeDateFromText(mr *MessageReader, size int32) interface{} {
s := mr.ReadString(size) s := mr.ReadString(size)
t, err := time.ParseInLocation("2006-01-02", s, time.Local) t, err := time.ParseInLocation("2006-01-02", s, time.Local)
if err != nil { if err != nil {
panic("Can't decode date") return ProtocolError(fmt.Sprintf("Can't decode date: %v", s))
} }
return t return t
} }
@@ -295,14 +295,14 @@ func decodeTimestampTzFromText(mr *MessageReader, size int32) interface{} {
s := mr.ReadString(size) s := mr.ReadString(size)
t, err := time.Parse("2006-01-02 15:04:05.999999-07", s) t, err := time.Parse("2006-01-02 15:04:05.999999-07", s)
if err != nil { if err != nil {
panic(fmt.Sprintf("Can't decode timestamptz: %v", err)) return ProtocolError(fmt.Sprintf("Can't decode timestamptz: %v - %v", err, s))
} }
return t return t
} }
func decodeTimestampTzFromBinary(mr *MessageReader, size int32) interface{} { func decodeTimestampTzFromBinary(mr *MessageReader, size int32) interface{} {
if size != 8 { if size != 8 {
panic("Received an invalid size for an int8") return ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", size))
} }
microsecFromUnixEpochToY2K := int64(946684800 * 1000000) microsecFromUnixEpochToY2K := int64(946684800 * 1000000)
microsecSinceY2K := mr.ReadInt64() microsecSinceY2K := mr.ReadInt64()
+2 -2
View File
@@ -6,7 +6,7 @@ import (
) )
func TestDateTranscode(t *testing.T) { func TestDateTranscode(t *testing.T) {
conn := getSharedConnection() conn := getSharedConnection(t)
actualDate := time.Date(2013, 1, 2, 0, 0, 0, 0, time.Local) actualDate := time.Date(2013, 1, 2, 0, 0, 0, 0, time.Local)
@@ -34,7 +34,7 @@ func TestDateTranscode(t *testing.T) {
} }
func TestTimestampTzTranscode(t *testing.T) { func TestTimestampTzTranscode(t *testing.T) {
conn := getSharedConnection() conn := getSharedConnection(t)
inputTime := time.Date(2013, 1, 2, 3, 4, 5, 6000, time.Local) inputTime := time.Date(2013, 1, 2, 3, 4, 5, 6000, time.Local)