diff --git a/sanitize.go b/sanitize.go index 7a07352f..0f60b21a 100644 --- a/sanitize.go +++ b/sanitize.go @@ -62,6 +62,18 @@ func (c *Connection) SanitizeSql(sql string, args ...interface{}) (output string return strconv.FormatFloat(arg, 'f', -1, 64) case []byte: return `E'\\x` + hex.EncodeToString(arg) + `'` + case []int16: + var s string + s, err = int16SliceToArrayString(arg) + return c.QuoteString(s) + case []int32: + var s string + s, err = int32SliceToArrayString(arg) + return c.QuoteString(s) + case []int64: + var s string + s, err = int64SliceToArrayString(arg) + return c.QuoteString(s) default: err = fmt.Errorf("Unable to sanitize type: %T", arg) return "" diff --git a/sanitize_test.go b/sanitize_test.go index 1bd18270..c668ece2 100644 --- a/sanitize_test.go +++ b/sanitize_test.go @@ -44,4 +44,34 @@ func TestSanitizeSql(t *testing.T) { if san, err := conn.SanitizeSql("select $1", bytea); err != nil || san != `select E'\\x000fff11'` { t.Errorf("Failed to sanitize []byte: %v - %v", san, err) } + + int2a := make([]int16, 4) + int2a[0] = 42 + int2a[1] = 0 + int2a[2] = -1 + int2a[3] = 32123 + + if san, err := conn.SanitizeSql("select $1::int2[]", int2a); err != nil || san != `select '{42,0,-1,32123}'::int2[]` { + t.Errorf("Failed to sanitize []int16: %v - %v", san, err) + } + + int4a := make([]int32, 4) + int4a[0] = 42 + int4a[1] = 0 + int4a[2] = -1 + int4a[3] = 32123 + + if san, err := conn.SanitizeSql("select $1::int4[]", int4a); err != nil || san != `select '{42,0,-1,32123}'::int4[]` { + t.Errorf("Failed to sanitize []int32: %v - %v", san, err) + } + + int8a := make([]int64, 4) + int8a[0] = 42 + int8a[1] = 0 + int8a[2] = -1 + int8a[3] = 32123 + + if san, err := conn.SanitizeSql("select $1::int8[]", int8a); err != nil || san != `select '{42,0,-1,32123}'::int8[]` { + t.Errorf("Failed to sanitize []int64: %v - %v", san, err) + } } diff --git a/value_transcoder.go b/value_transcoder.go index 2a16184a..26dc6298 100644 --- a/value_transcoder.go +++ b/value_transcoder.go @@ -1,8 +1,10 @@ package pgx import ( + "bytes" "encoding/hex" "fmt" + "regexp" "strconv" "time" "unsafe" @@ -86,6 +88,21 @@ func init() { EncodeTo: encodeFloat8, EncodeFormat: 1} + // int2[] + ValueTranscoders[Oid(1005)] = &ValueTranscoder{ + DecodeText: decodeInt2ArrayFromText, + EncodeTo: encodeInt2Array} + + // int4[] + ValueTranscoders[Oid(1007)] = &ValueTranscoder{ + DecodeText: decodeInt4ArrayFromText, + EncodeTo: encodeInt4Array} + + // int8[] + ValueTranscoders[Oid(1016)] = &ValueTranscoder{ + DecodeText: decodeInt8ArrayFromText, + EncodeTo: encodeInt8Array} + // varchar -- same as text ValueTranscoders[Oid(1043)] = ValueTranscoders[Oid(25)] @@ -104,6 +121,24 @@ func init() { defaultTranscoder = ValueTranscoders[Oid(25)] } +var arrayEl *regexp.Regexp = regexp.MustCompile(`[{,](?:"((?:[^"\\]|\\.)*)"|(NULL)|([^,}]+))`) + +// SplitArrayText +func SplitArrayText(text string) (elements []string) { + matches := arrayEl.FindAllStringSubmatch(text, -1) + elements = make([]string, 0, len(matches)) + for _, match := range matches { + if match[1] != "" { + elements = append(elements, match[1]) + } else if match[2] != "" { + elements = append(elements, match[2]) + } else if match[3] != "" { + elements = append(elements, match[3]) + } + } + return +} + func decodeBoolFromText(mr *MessageReader, size int32) interface{} { s := mr.ReadString(size) switch s { @@ -320,3 +355,126 @@ func encodeTimestampTz(w *MessageWriter, value interface{}) { w.Write(int32(len(s))) w.WriteString(s) } + +func decodeInt2ArrayFromText(mr *MessageReader, size int32) interface{} { + s := mr.ReadString(size) + + elements := SplitArrayText(s) + + numbers := make([]int16, 0, len(elements)) + + for _, e := range elements { + n, err := strconv.ParseInt(e, 10, 16) + if err != nil { + return ProtocolError(fmt.Sprintf("Received invalid int2[]: %v", s)) + } + numbers = append(numbers, int16(n)) + } + + return numbers +} + +func int16SliceToArrayString(nums []int16) (string, error) { + w := newMessageWriter(&bytes.Buffer{}) + w.WriteString("{") + for i, n := range nums { + if i > 0 { + w.WriteString(",") + } + w.WriteString(strconv.FormatInt(int64(n), 10)) + } + w.WriteString("}") + return w.buf.String(), w.Err +} + +func encodeInt2Array(w *MessageWriter, value interface{}) { + v := value.([]int16) + s, err := int16SliceToArrayString(v) + if err != nil { + w.Err = fmt.Errorf("Failed to encode []int16: %v", err) + } + w.Write(int32(len(s))) + w.WriteString(s) +} + +func decodeInt4ArrayFromText(mr *MessageReader, size int32) interface{} { + s := mr.ReadString(size) + + elements := SplitArrayText(s) + + numbers := make([]int32, 0, len(elements)) + + for _, e := range elements { + n, err := strconv.ParseInt(e, 10, 16) + if err != nil { + return ProtocolError(fmt.Sprintf("Received invalid int4[]: %v", s)) + } + numbers = append(numbers, int32(n)) + } + + return numbers +} + +func int32SliceToArrayString(nums []int32) (string, error) { + w := newMessageWriter(&bytes.Buffer{}) + w.WriteString("{") + for i, n := range nums { + if i > 0 { + w.WriteString(",") + } + w.WriteString(strconv.FormatInt(int64(n), 10)) + } + w.WriteString("}") + return w.buf.String(), w.Err +} + +func encodeInt4Array(w *MessageWriter, value interface{}) { + v := value.([]int32) + s, err := int32SliceToArrayString(v) + if err != nil { + w.Err = fmt.Errorf("Failed to encode []int32: %v", err) + } + w.Write(int32(len(s))) + w.WriteString(s) +} + +func decodeInt8ArrayFromText(mr *MessageReader, size int32) interface{} { + s := mr.ReadString(size) + + elements := SplitArrayText(s) + + numbers := make([]int64, 0, len(elements)) + + for _, e := range elements { + n, err := strconv.ParseInt(e, 10, 16) + if err != nil { + return ProtocolError(fmt.Sprintf("Received invalid int8[]: %v", s)) + } + numbers = append(numbers, int64(n)) + } + + return numbers +} + +func int64SliceToArrayString(nums []int64) (string, error) { + w := newMessageWriter(&bytes.Buffer{}) + w.WriteString("{") + for i, n := range nums { + if i > 0 { + w.WriteString(",") + } + w.WriteString(strconv.FormatInt(int64(n), 10)) + } + w.WriteString("}") + return w.buf.String(), w.Err +} + +func encodeInt8Array(w *MessageWriter, value interface{}) { + v := value.([]int64) + s, err := int64SliceToArrayString(v) + if err != nil { + w.Err = fmt.Errorf("Failed to encode []int64: %v", err) + } + w.Write(int32(len(s))) + w.WriteString(s) +} diff --git a/value_transcoder_test.go b/value_transcoder_test.go index 0d0227a4..cce7cd62 100644 --- a/value_transcoder_test.go +++ b/value_transcoder_test.go @@ -60,3 +60,96 @@ func TestTimestampTzTranscode(t *testing.T) { t.Errorf("Did not transcode time successfully: %v is not %v", outputTime, inputTime) } } + +func TestInt2SliceTranscode(t *testing.T) { + testEqual := func(a, b []int16) { + if len(a) != len(b) { + t.Errorf("Did not transcode []int16 successfully: %v is not %v", a, b) + } + for i := range a { + if a[i] != b[i] { + t.Errorf("Did not transcode []int16 successfully: %v is not %v", a, b) + } + } + } + + conn := getSharedConnection(t) + + inputNumbers := []int16{1, 2, 3, 4, 5, 6, 7, 8} + var outputNumbers []int16 + + outputNumbers = mustSelectValue(t, conn, "select $1::int2[]", inputNumbers).([]int16) + testEqual(inputNumbers, outputNumbers) + + mustPrepare(t, conn, "testTranscode", "select $1::int2[]") + defer func() { + if err := conn.Deallocate("testTranscode"); err != nil { + t.Fatalf("Unable to deallocate prepared statement: %v", err) + } + }() + + outputNumbers = mustSelectValue(t, conn, "testTranscode", inputNumbers).([]int16) + testEqual(inputNumbers, outputNumbers) +} + +func TestInt4SliceTranscode(t *testing.T) { + testEqual := func(a, b []int32) { + if len(a) != len(b) { + t.Errorf("Did not transcode []int32 successfully: %v is not %v", a, b) + } + for i := range a { + if a[i] != b[i] { + t.Errorf("Did not transcode []int32 successfully: %v is not %v", a, b) + } + } + } + + conn := getSharedConnection(t) + + inputNumbers := []int32{1, 2, 3, 4, 5, 6, 7, 8} + var outputNumbers []int32 + + outputNumbers = mustSelectValue(t, conn, "select $1::int4[]", inputNumbers).([]int32) + testEqual(inputNumbers, outputNumbers) + + mustPrepare(t, conn, "testTranscode", "select $1::int4[]") + defer func() { + if err := conn.Deallocate("testTranscode"); err != nil { + t.Fatalf("Unable to deallocate prepared statement: %v", err) + } + }() + + outputNumbers = mustSelectValue(t, conn, "testTranscode", inputNumbers).([]int32) + testEqual(inputNumbers, outputNumbers) +} + +func TestInt8SliceTranscode(t *testing.T) { + testEqual := func(a, b []int64) { + if len(a) != len(b) { + t.Errorf("Did not transcode []int64 successfully: %v is not %v", a, b) + } + for i := range a { + if a[i] != b[i] { + t.Errorf("Did not transcode []int64 successfully: %v is not %v", a, b) + } + } + } + + conn := getSharedConnection(t) + + inputNumbers := []int64{1, 2, 3, 4, 5, 6, 7, 8} + var outputNumbers []int64 + + outputNumbers = mustSelectValue(t, conn, "select $1::int8[]", inputNumbers).([]int64) + testEqual(inputNumbers, outputNumbers) + + mustPrepare(t, conn, "testTranscode", "select $1::int8[]") + defer func() { + if err := conn.Deallocate("testTranscode"); err != nil { + t.Fatalf("Unable to deallocate prepared statement: %v", err) + } + }() + + outputNumbers = mustSelectValue(t, conn, "testTranscode", inputNumbers).([]int64) + testEqual(inputNumbers, outputNumbers) +}