Add basic array transcoding for int16, int32, and int64
This commit is contained in:
+12
@@ -62,6 +62,18 @@ func (c *Connection) SanitizeSql(sql string, args ...interface{}) (output string
|
|||||||
return strconv.FormatFloat(arg, 'f', -1, 64)
|
return strconv.FormatFloat(arg, 'f', -1, 64)
|
||||||
case []byte:
|
case []byte:
|
||||||
return `E'\\x` + hex.EncodeToString(arg) + `'`
|
return `E'\\x` + hex.EncodeToString(arg) + `'`
|
||||||
|
case []int16:
|
||||||
|
var s string
|
||||||
|
s, err = int16SliceToArrayString(arg)
|
||||||
|
return c.QuoteString(s)
|
||||||
|
case []int32:
|
||||||
|
var s string
|
||||||
|
s, err = int32SliceToArrayString(arg)
|
||||||
|
return c.QuoteString(s)
|
||||||
|
case []int64:
|
||||||
|
var s string
|
||||||
|
s, err = int64SliceToArrayString(arg)
|
||||||
|
return c.QuoteString(s)
|
||||||
default:
|
default:
|
||||||
err = fmt.Errorf("Unable to sanitize type: %T", arg)
|
err = fmt.Errorf("Unable to sanitize type: %T", arg)
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -44,4 +44,34 @@ func TestSanitizeSql(t *testing.T) {
|
|||||||
if san, err := conn.SanitizeSql("select $1", bytea); err != nil || san != `select E'\\x000fff11'` {
|
if san, err := conn.SanitizeSql("select $1", bytea); err != nil || san != `select E'\\x000fff11'` {
|
||||||
t.Errorf("Failed to sanitize []byte: %v - %v", san, err)
|
t.Errorf("Failed to sanitize []byte: %v - %v", san, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int2a := make([]int16, 4)
|
||||||
|
int2a[0] = 42
|
||||||
|
int2a[1] = 0
|
||||||
|
int2a[2] = -1
|
||||||
|
int2a[3] = 32123
|
||||||
|
|
||||||
|
if san, err := conn.SanitizeSql("select $1::int2[]", int2a); err != nil || san != `select '{42,0,-1,32123}'::int2[]` {
|
||||||
|
t.Errorf("Failed to sanitize []int16: %v - %v", san, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
int4a := make([]int32, 4)
|
||||||
|
int4a[0] = 42
|
||||||
|
int4a[1] = 0
|
||||||
|
int4a[2] = -1
|
||||||
|
int4a[3] = 32123
|
||||||
|
|
||||||
|
if san, err := conn.SanitizeSql("select $1::int4[]", int4a); err != nil || san != `select '{42,0,-1,32123}'::int4[]` {
|
||||||
|
t.Errorf("Failed to sanitize []int32: %v - %v", san, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
int8a := make([]int64, 4)
|
||||||
|
int8a[0] = 42
|
||||||
|
int8a[1] = 0
|
||||||
|
int8a[2] = -1
|
||||||
|
int8a[3] = 32123
|
||||||
|
|
||||||
|
if san, err := conn.SanitizeSql("select $1::int8[]", int8a); err != nil || san != `select '{42,0,-1,32123}'::int8[]` {
|
||||||
|
t.Errorf("Failed to sanitize []int64: %v - %v", san, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
package pgx
|
package pgx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
@@ -86,6 +88,21 @@ func init() {
|
|||||||
EncodeTo: encodeFloat8,
|
EncodeTo: encodeFloat8,
|
||||||
EncodeFormat: 1}
|
EncodeFormat: 1}
|
||||||
|
|
||||||
|
// int2[]
|
||||||
|
ValueTranscoders[Oid(1005)] = &ValueTranscoder{
|
||||||
|
DecodeText: decodeInt2ArrayFromText,
|
||||||
|
EncodeTo: encodeInt2Array}
|
||||||
|
|
||||||
|
// int4[]
|
||||||
|
ValueTranscoders[Oid(1007)] = &ValueTranscoder{
|
||||||
|
DecodeText: decodeInt4ArrayFromText,
|
||||||
|
EncodeTo: encodeInt4Array}
|
||||||
|
|
||||||
|
// int8[]
|
||||||
|
ValueTranscoders[Oid(1016)] = &ValueTranscoder{
|
||||||
|
DecodeText: decodeInt8ArrayFromText,
|
||||||
|
EncodeTo: encodeInt8Array}
|
||||||
|
|
||||||
// varchar -- same as text
|
// varchar -- same as text
|
||||||
ValueTranscoders[Oid(1043)] = ValueTranscoders[Oid(25)]
|
ValueTranscoders[Oid(1043)] = ValueTranscoders[Oid(25)]
|
||||||
|
|
||||||
@@ -104,6 +121,24 @@ func init() {
|
|||||||
defaultTranscoder = ValueTranscoders[Oid(25)]
|
defaultTranscoder = ValueTranscoders[Oid(25)]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var arrayEl *regexp.Regexp = regexp.MustCompile(`[{,](?:"((?:[^"\\]|\\.)*)"|(NULL)|([^,}]+))`)
|
||||||
|
|
||||||
|
// SplitArrayText
|
||||||
|
func SplitArrayText(text string) (elements []string) {
|
||||||
|
matches := arrayEl.FindAllStringSubmatch(text, -1)
|
||||||
|
elements = make([]string, 0, len(matches))
|
||||||
|
for _, match := range matches {
|
||||||
|
if match[1] != "" {
|
||||||
|
elements = append(elements, match[1])
|
||||||
|
} else if match[2] != "" {
|
||||||
|
elements = append(elements, match[2])
|
||||||
|
} else if match[3] != "" {
|
||||||
|
elements = append(elements, match[3])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func decodeBoolFromText(mr *MessageReader, size int32) interface{} {
|
func decodeBoolFromText(mr *MessageReader, size int32) interface{} {
|
||||||
s := mr.ReadString(size)
|
s := mr.ReadString(size)
|
||||||
switch s {
|
switch s {
|
||||||
@@ -320,3 +355,126 @@ func encodeTimestampTz(w *MessageWriter, value interface{}) {
|
|||||||
w.Write(int32(len(s)))
|
w.Write(int32(len(s)))
|
||||||
w.WriteString(s)
|
w.WriteString(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func decodeInt2ArrayFromText(mr *MessageReader, size int32) interface{} {
|
||||||
|
s := mr.ReadString(size)
|
||||||
|
|
||||||
|
elements := SplitArrayText(s)
|
||||||
|
|
||||||
|
numbers := make([]int16, 0, len(elements))
|
||||||
|
|
||||||
|
for _, e := range elements {
|
||||||
|
n, err := strconv.ParseInt(e, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return ProtocolError(fmt.Sprintf("Received invalid int2[]: %v", s))
|
||||||
|
}
|
||||||
|
numbers = append(numbers, int16(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
return numbers
|
||||||
|
}
|
||||||
|
|
||||||
|
func int16SliceToArrayString(nums []int16) (string, error) {
|
||||||
|
w := newMessageWriter(&bytes.Buffer{})
|
||||||
|
w.WriteString("{")
|
||||||
|
for i, n := range nums {
|
||||||
|
if i > 0 {
|
||||||
|
w.WriteString(",")
|
||||||
|
}
|
||||||
|
w.WriteString(strconv.FormatInt(int64(n), 10))
|
||||||
|
}
|
||||||
|
w.WriteString("}")
|
||||||
|
return w.buf.String(), w.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeInt2Array(w *MessageWriter, value interface{}) {
|
||||||
|
v := value.([]int16)
|
||||||
|
s, err := int16SliceToArrayString(v)
|
||||||
|
if err != nil {
|
||||||
|
w.Err = fmt.Errorf("Failed to encode []int16: %v", err)
|
||||||
|
}
|
||||||
|
w.Write(int32(len(s)))
|
||||||
|
w.WriteString(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeInt4ArrayFromText(mr *MessageReader, size int32) interface{} {
|
||||||
|
s := mr.ReadString(size)
|
||||||
|
|
||||||
|
elements := SplitArrayText(s)
|
||||||
|
|
||||||
|
numbers := make([]int32, 0, len(elements))
|
||||||
|
|
||||||
|
for _, e := range elements {
|
||||||
|
n, err := strconv.ParseInt(e, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return ProtocolError(fmt.Sprintf("Received invalid int4[]: %v", s))
|
||||||
|
}
|
||||||
|
numbers = append(numbers, int32(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
return numbers
|
||||||
|
}
|
||||||
|
|
||||||
|
func int32SliceToArrayString(nums []int32) (string, error) {
|
||||||
|
w := newMessageWriter(&bytes.Buffer{})
|
||||||
|
w.WriteString("{")
|
||||||
|
for i, n := range nums {
|
||||||
|
if i > 0 {
|
||||||
|
w.WriteString(",")
|
||||||
|
}
|
||||||
|
w.WriteString(strconv.FormatInt(int64(n), 10))
|
||||||
|
}
|
||||||
|
w.WriteString("}")
|
||||||
|
return w.buf.String(), w.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeInt4Array(w *MessageWriter, value interface{}) {
|
||||||
|
v := value.([]int32)
|
||||||
|
s, err := int32SliceToArrayString(v)
|
||||||
|
if err != nil {
|
||||||
|
w.Err = fmt.Errorf("Failed to encode []int32: %v", err)
|
||||||
|
}
|
||||||
|
w.Write(int32(len(s)))
|
||||||
|
w.WriteString(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeInt8ArrayFromText(mr *MessageReader, size int32) interface{} {
|
||||||
|
s := mr.ReadString(size)
|
||||||
|
|
||||||
|
elements := SplitArrayText(s)
|
||||||
|
|
||||||
|
numbers := make([]int64, 0, len(elements))
|
||||||
|
|
||||||
|
for _, e := range elements {
|
||||||
|
n, err := strconv.ParseInt(e, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return ProtocolError(fmt.Sprintf("Received invalid int8[]: %v", s))
|
||||||
|
}
|
||||||
|
numbers = append(numbers, int64(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
return numbers
|
||||||
|
}
|
||||||
|
|
||||||
|
func int64SliceToArrayString(nums []int64) (string, error) {
|
||||||
|
w := newMessageWriter(&bytes.Buffer{})
|
||||||
|
w.WriteString("{")
|
||||||
|
for i, n := range nums {
|
||||||
|
if i > 0 {
|
||||||
|
w.WriteString(",")
|
||||||
|
}
|
||||||
|
w.WriteString(strconv.FormatInt(int64(n), 10))
|
||||||
|
}
|
||||||
|
w.WriteString("}")
|
||||||
|
return w.buf.String(), w.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeInt8Array(w *MessageWriter, value interface{}) {
|
||||||
|
v := value.([]int64)
|
||||||
|
s, err := int64SliceToArrayString(v)
|
||||||
|
if err != nil {
|
||||||
|
w.Err = fmt.Errorf("Failed to encode []int64: %v", err)
|
||||||
|
}
|
||||||
|
w.Write(int32(len(s)))
|
||||||
|
w.WriteString(s)
|
||||||
|
}
|
||||||
|
|||||||
@@ -60,3 +60,96 @@ func TestTimestampTzTranscode(t *testing.T) {
|
|||||||
t.Errorf("Did not transcode time successfully: %v is not %v", outputTime, inputTime)
|
t.Errorf("Did not transcode time successfully: %v is not %v", outputTime, inputTime)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInt2SliceTranscode(t *testing.T) {
|
||||||
|
testEqual := func(a, b []int16) {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
t.Errorf("Did not transcode []int16 successfully: %v is not %v", a, b)
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
if a[i] != b[i] {
|
||||||
|
t.Errorf("Did not transcode []int16 successfully: %v is not %v", a, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := getSharedConnection(t)
|
||||||
|
|
||||||
|
inputNumbers := []int16{1, 2, 3, 4, 5, 6, 7, 8}
|
||||||
|
var outputNumbers []int16
|
||||||
|
|
||||||
|
outputNumbers = mustSelectValue(t, conn, "select $1::int2[]", inputNumbers).([]int16)
|
||||||
|
testEqual(inputNumbers, outputNumbers)
|
||||||
|
|
||||||
|
mustPrepare(t, conn, "testTranscode", "select $1::int2[]")
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Deallocate("testTranscode"); err != nil {
|
||||||
|
t.Fatalf("Unable to deallocate prepared statement: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
outputNumbers = mustSelectValue(t, conn, "testTranscode", inputNumbers).([]int16)
|
||||||
|
testEqual(inputNumbers, outputNumbers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInt4SliceTranscode(t *testing.T) {
|
||||||
|
testEqual := func(a, b []int32) {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
t.Errorf("Did not transcode []int32 successfully: %v is not %v", a, b)
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
if a[i] != b[i] {
|
||||||
|
t.Errorf("Did not transcode []int32 successfully: %v is not %v", a, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := getSharedConnection(t)
|
||||||
|
|
||||||
|
inputNumbers := []int32{1, 2, 3, 4, 5, 6, 7, 8}
|
||||||
|
var outputNumbers []int32
|
||||||
|
|
||||||
|
outputNumbers = mustSelectValue(t, conn, "select $1::int4[]", inputNumbers).([]int32)
|
||||||
|
testEqual(inputNumbers, outputNumbers)
|
||||||
|
|
||||||
|
mustPrepare(t, conn, "testTranscode", "select $1::int4[]")
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Deallocate("testTranscode"); err != nil {
|
||||||
|
t.Fatalf("Unable to deallocate prepared statement: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
outputNumbers = mustSelectValue(t, conn, "testTranscode", inputNumbers).([]int32)
|
||||||
|
testEqual(inputNumbers, outputNumbers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInt8SliceTranscode(t *testing.T) {
|
||||||
|
testEqual := func(a, b []int64) {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
t.Errorf("Did not transcode []int64 successfully: %v is not %v", a, b)
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
if a[i] != b[i] {
|
||||||
|
t.Errorf("Did not transcode []int64 successfully: %v is not %v", a, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := getSharedConnection(t)
|
||||||
|
|
||||||
|
inputNumbers := []int64{1, 2, 3, 4, 5, 6, 7, 8}
|
||||||
|
var outputNumbers []int64
|
||||||
|
|
||||||
|
outputNumbers = mustSelectValue(t, conn, "select $1::int8[]", inputNumbers).([]int64)
|
||||||
|
testEqual(inputNumbers, outputNumbers)
|
||||||
|
|
||||||
|
mustPrepare(t, conn, "testTranscode", "select $1::int8[]")
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Deallocate("testTranscode"); err != nil {
|
||||||
|
t.Fatalf("Unable to deallocate prepared statement: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
outputNumbers = mustSelectValue(t, conn, "testTranscode", inputNumbers).([]int64)
|
||||||
|
testEqual(inputNumbers, outputNumbers)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user