Merge branch 'add-aclitem-array' of https://github.com/manniwood/pgx into manniwood-add-aclitem-array
This commit is contained in:
@@ -0,0 +1,126 @@
|
|||||||
|
package pgx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEscapeAclItem(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"foo",
|
||||||
|
"foo",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`foo, "\}`,
|
||||||
|
`foo\, \"\\\}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
actual, err := escapeAclItem(tt.input)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%d. Unexpected error %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual != tt.expected {
|
||||||
|
t.Errorf("%d.\nexpected: %s,\nactual: %s", i, tt.expected, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseAclItemArray(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected []AclItem
|
||||||
|
errMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"",
|
||||||
|
[]AclItem{},
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"one",
|
||||||
|
[]AclItem{"one"},
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`"one"`,
|
||||||
|
[]AclItem{"one"},
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"one,two,three",
|
||||||
|
[]AclItem{"one", "two", "three"},
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`"one","two","three"`,
|
||||||
|
[]AclItem{"one", "two", "three"},
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`"one",two,"three"`,
|
||||||
|
[]AclItem{"one", "two", "three"},
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`one,two,"three"`,
|
||||||
|
[]AclItem{"one", "two", "three"},
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`"one","two",three`,
|
||||||
|
[]AclItem{"one", "two", "three"},
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`"one","t w o",three`,
|
||||||
|
[]AclItem{"one", "t w o", "three"},
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`"one","t, w o\"\}\\",three`,
|
||||||
|
[]AclItem{"one", `t, w o"}\`, "three"},
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`"one","two",three"`,
|
||||||
|
[]AclItem{"one", "two", `three"`},
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`"one","two,"three"`,
|
||||||
|
nil,
|
||||||
|
"unexpected rune after quoted value",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`"one","two","three`,
|
||||||
|
nil,
|
||||||
|
"unexpected end of quoted value",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
actual, err := parseAclItemArray(tt.input)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if tt.errMsg == "" {
|
||||||
|
t.Errorf("%d. Unexpected error %v", i, err)
|
||||||
|
} else if err.Error() != tt.errMsg {
|
||||||
|
t.Errorf("%d. Expected error %v did not match actual error %v", i, tt.errMsg, err.Error())
|
||||||
|
}
|
||||||
|
} else if tt.errMsg != "" {
|
||||||
|
t.Errorf("%d. Expected error not returned: \"%v\"", i, tt.errMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(actual, tt.expected) {
|
||||||
|
t.Errorf("%d. Expected %v did not match actual %v", i, tt.expected, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
@@ -45,6 +46,7 @@ const (
|
|||||||
Float4ArrayOid = 1021
|
Float4ArrayOid = 1021
|
||||||
Float8ArrayOid = 1022
|
Float8ArrayOid = 1022
|
||||||
AclItemOid = 1033
|
AclItemOid = 1033
|
||||||
|
AclItemArrayOid = 1034
|
||||||
InetArrayOid = 1041
|
InetArrayOid = 1041
|
||||||
VarcharOid = 1043
|
VarcharOid = 1043
|
||||||
DateOid = 1082
|
DateOid = 1082
|
||||||
@@ -77,6 +79,7 @@ var DefaultTypeFormats map[string]int16
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
DefaultTypeFormats = map[string]int16{
|
DefaultTypeFormats = map[string]int16{
|
||||||
|
"_aclitem": TextFormatCode, // Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin)
|
||||||
"_bool": BinaryFormatCode,
|
"_bool": BinaryFormatCode,
|
||||||
"_bytea": BinaryFormatCode,
|
"_bytea": BinaryFormatCode,
|
||||||
"_cidr": BinaryFormatCode,
|
"_cidr": BinaryFormatCode,
|
||||||
@@ -981,6 +984,8 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error {
|
|||||||
return Encode(wbuf, oid, v)
|
return Encode(wbuf, oid, v)
|
||||||
case string:
|
case string:
|
||||||
return encodeString(wbuf, oid, arg)
|
return encodeString(wbuf, oid, arg)
|
||||||
|
case []AclItem:
|
||||||
|
return encodeAclItemSlice(wbuf, oid, arg)
|
||||||
case []byte:
|
case []byte:
|
||||||
return encodeByteSlice(wbuf, oid, arg)
|
return encodeByteSlice(wbuf, oid, arg)
|
||||||
case [][]byte:
|
case [][]byte:
|
||||||
@@ -1224,6 +1229,8 @@ func Decode(vr *ValueReader, d interface{}) error {
|
|||||||
*v = decodeFloat4(vr)
|
*v = decodeFloat4(vr)
|
||||||
case *float64:
|
case *float64:
|
||||||
*v = decodeFloat8(vr)
|
*v = decodeFloat8(vr)
|
||||||
|
case *[]AclItem:
|
||||||
|
*v = decodeAclItemArray(vr)
|
||||||
case *[]bool:
|
case *[]bool:
|
||||||
*v = decodeBoolArray(vr)
|
*v = decodeBoolArray(vr)
|
||||||
case *[]int16:
|
case *[]int16:
|
||||||
@@ -2993,6 +3000,210 @@ func decodeTextArray(vr *ValueReader) []string {
|
|||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// escapeAclItem escapes an AclItem before it is added to
|
||||||
|
// its aclitem[] string representation. The PostgreSQL aclitem
|
||||||
|
// datatype itself can need escapes because it follows the
|
||||||
|
// formatting rules of SQL identifiers. Think of this function
|
||||||
|
// as escaping the escapes, so that PostgreSQL's array parser
|
||||||
|
// will do the right thing.
|
||||||
|
func escapeAclItem(acl string) (string, error) {
|
||||||
|
var escapedAclItem bytes.Buffer
|
||||||
|
reader := strings.NewReader(acl)
|
||||||
|
for {
|
||||||
|
rn, _, err := reader.ReadRune()
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
// Here, EOF is an expected end state, not an error.
|
||||||
|
return escapedAclItem.String(), nil
|
||||||
|
}
|
||||||
|
// This error was not expected
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if needsEscape(rn) {
|
||||||
|
escapedAclItem.WriteRune('\\')
|
||||||
|
}
|
||||||
|
escapedAclItem.WriteRune(rn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// needsEscape determines whether or not a rune needs escaping
|
||||||
|
// before being placed in the textual representation of an
|
||||||
|
// aclitem[] array.
|
||||||
|
func needsEscape(rn rune) bool {
|
||||||
|
return rn == '\\' || rn == ',' || rn == '"' || rn == '}'
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeAclItemSlice encodes a slice of AclItems in
|
||||||
|
// their textual represention for PostgreSQL.
|
||||||
|
func encodeAclItemSlice(w *WriteBuf, oid Oid, aclitems []AclItem) error {
|
||||||
|
strs := make([]string, len(aclitems))
|
||||||
|
var escapedAclItem string
|
||||||
|
var err error
|
||||||
|
for i := range strs {
|
||||||
|
escapedAclItem, err = escapeAclItem(string(aclitems[i]))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
strs[i] = string(escapedAclItem)
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
buf.WriteRune('{')
|
||||||
|
buf.WriteString(strings.Join(strs, ","))
|
||||||
|
buf.WriteRune('}')
|
||||||
|
str := buf.String()
|
||||||
|
w.WriteInt32(int32(len(str)))
|
||||||
|
w.WriteBytes([]byte(str))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseAclItemArray parses the textual representation
|
||||||
|
// of the aclitem[] type. The textual representation is chosen because
|
||||||
|
// Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin).
|
||||||
|
// See https://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
|
||||||
|
// for formatting notes.
|
||||||
|
func parseAclItemArray(arr string) ([]AclItem, error) {
|
||||||
|
reader := strings.NewReader(arr)
|
||||||
|
// Difficult to guess a performant initial capacity for a slice of
|
||||||
|
// aclitems, but let's go with 5.
|
||||||
|
aclItems := make([]AclItem, 0, 5)
|
||||||
|
// A single value
|
||||||
|
aclItem := AclItem("")
|
||||||
|
for {
|
||||||
|
// Grab the first/next/last rune to see if we are dealing with a
|
||||||
|
// quoted value, an unquoted value, or the end of the string.
|
||||||
|
rn, _, err := reader.ReadRune()
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
// Here, EOF is an expected end state, not an error.
|
||||||
|
return aclItems, nil
|
||||||
|
}
|
||||||
|
// This error was not expected
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if rn == '"' {
|
||||||
|
// Discard the opening quote of the quoted value.
|
||||||
|
aclItem, err = parseQuotedAclItem(reader)
|
||||||
|
} else {
|
||||||
|
// We have just read the first rune of an unquoted (bare) value;
|
||||||
|
// put it back so that ParseBareValue can read it.
|
||||||
|
err := reader.UnreadRune()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
aclItem, err = parseBareAclItem(reader)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
// Here, EOF is an expected end state, not an error..
|
||||||
|
aclItems = append(aclItems, aclItem)
|
||||||
|
return aclItems, nil
|
||||||
|
}
|
||||||
|
// This error was not expected.
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
aclItems = append(aclItems, aclItem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseBareAclItem parses a bare (unquoted) aclitem from reader
|
||||||
|
func parseBareAclItem(reader *strings.Reader) (AclItem, error) {
|
||||||
|
var aclItem bytes.Buffer
|
||||||
|
for {
|
||||||
|
rn, _, err := reader.ReadRune()
|
||||||
|
if err != nil {
|
||||||
|
// Return the read value in case the error is a harmless io.EOF.
|
||||||
|
// (io.EOF marks the end of a bare aclitem at the end of a string)
|
||||||
|
return AclItem(aclItem.String()), err
|
||||||
|
}
|
||||||
|
if rn == ',' {
|
||||||
|
// A comma marks the end of a bare aclitem.
|
||||||
|
return AclItem(aclItem.String()), nil
|
||||||
|
} else {
|
||||||
|
aclItem.WriteRune(rn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseQuotedAclItem parses an aclitem which is in double quotes from reader
|
||||||
|
func parseQuotedAclItem(reader *strings.Reader) (AclItem, error) {
|
||||||
|
var aclItem bytes.Buffer
|
||||||
|
for {
|
||||||
|
rn, escaped, err := readPossiblyEscapedRune(reader)
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
// Even when it is the last value, the final rune of
|
||||||
|
// a quoted aclitem should be the final closing quote, not io.EOF.
|
||||||
|
return AclItem(""), fmt.Errorf("unexpected end of quoted value")
|
||||||
|
}
|
||||||
|
// Return the read aclitem in case the error is a harmless io.EOF,
|
||||||
|
// which will be determined by the caller.
|
||||||
|
return AclItem(aclItem.String()), err
|
||||||
|
}
|
||||||
|
if !escaped && rn == '"' {
|
||||||
|
// An unescaped double quote marks the end of a quoted value.
|
||||||
|
// The next rune should either be a comma or the end of the string.
|
||||||
|
rn, _, err := reader.ReadRune()
|
||||||
|
if err != nil {
|
||||||
|
// Return the read value in case the error is a harmless io.EOF,
|
||||||
|
// which will be determined by the caller.
|
||||||
|
return AclItem(aclItem.String()), err
|
||||||
|
}
|
||||||
|
if rn != ',' {
|
||||||
|
return AclItem(""), fmt.Errorf("unexpected rune after quoted value")
|
||||||
|
}
|
||||||
|
return AclItem(aclItem.String()), nil
|
||||||
|
}
|
||||||
|
aclItem.WriteRune(rn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the next rune from r, unless it is a backslash;
|
||||||
|
// in that case, it returns the rune after the backslash. The second
|
||||||
|
// return value tells us whether or not the rune was
|
||||||
|
// preceeded by a backslash (escaped).
|
||||||
|
func readPossiblyEscapedRune(reader *strings.Reader) (rune, bool, error) {
|
||||||
|
rn, _, err := reader.ReadRune()
|
||||||
|
if err != nil {
|
||||||
|
return 0, false, err
|
||||||
|
}
|
||||||
|
if rn == '\\' {
|
||||||
|
// Discard the backslash and read the next rune.
|
||||||
|
rn, _, err = reader.ReadRune()
|
||||||
|
if err != nil {
|
||||||
|
return 0, false, err
|
||||||
|
}
|
||||||
|
return rn, true, nil
|
||||||
|
}
|
||||||
|
return rn, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeAclItemArray(vr *ValueReader) []AclItem {
|
||||||
|
if vr.Len() == -1 {
|
||||||
|
vr.Fatal(ProtocolError("Cannot decode null into []AclItem"))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
str := vr.ReadString(vr.Len())
|
||||||
|
|
||||||
|
// Short-circuit empty array.
|
||||||
|
if str == "{}" {
|
||||||
|
return []AclItem{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the '{' at the front and the '}' at the end,
|
||||||
|
// so that parseAclItemArray doesn't have to deal with them.
|
||||||
|
str = str[1 : len(str)-1]
|
||||||
|
aclItems, err := parseAclItemArray(str)
|
||||||
|
if err != nil {
|
||||||
|
vr.Fatal(ProtocolError(err.Error()))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return aclItems
|
||||||
|
}
|
||||||
|
|
||||||
func encodeStringSlice(w *WriteBuf, oid Oid, slice []string) error {
|
func encodeStringSlice(w *WriteBuf, oid Oid, slice []string) error {
|
||||||
var elOid Oid
|
var elOid Oid
|
||||||
switch oid {
|
switch oid {
|
||||||
|
|||||||
@@ -643,6 +643,52 @@ func TestNullX(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func assertAclItemSlicesEqual(t *testing.T, query, scan []pgx.AclItem) {
|
||||||
|
if !reflect.DeepEqual(query, scan) {
|
||||||
|
t.Errorf("failed to encode aclitem[]\n EXPECTED: %d %v\n ACTUAL: %d %v", len(query), query, len(scan), scan)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAclArrayDecoding(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
sql := "select $1::aclitem[]"
|
||||||
|
var scan []pgx.AclItem
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
query []pgx.AclItem
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
[]pgx.AclItem{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]pgx.AclItem{"=r/postgres"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres", `postgres=arwdDxt/" tricky, ' } "" \ test user "`},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for i, tt := range tests {
|
||||||
|
err := conn.QueryRow(sql, tt.query).Scan(&scan)
|
||||||
|
if err != nil {
|
||||||
|
// t.Errorf(`%d. error reading array: %v`, i, err)
|
||||||
|
t.Errorf(`%d. error reading array: %v query: %s`, i, err, tt.query)
|
||||||
|
if pgerr, ok := err.(pgx.PgError); ok {
|
||||||
|
t.Errorf(`%d. error reading array (detail): %s`, i, pgerr.Detail)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
assertAclItemSlicesEqual(t, tt.query, scan)
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestArrayDecoding(t *testing.T) {
|
func TestArrayDecoding(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user