diff --git a/aclitem_parse_test.go b/aclitem_parse_test.go new file mode 100644 index 00000000..5c7c748f --- /dev/null +++ b/aclitem_parse_test.go @@ -0,0 +1,126 @@ +package pgx + +import ( + "reflect" + "testing" +) + +func TestEscapeAclItem(t *testing.T) { + tests := []struct { + input string + expected string + }{ + { + "foo", + "foo", + }, + { + `foo, "\}`, + `foo\, \"\\\}`, + }, + } + + for i, tt := range tests { + actual, err := escapeAclItem(tt.input) + + if err != nil { + t.Errorf("%d. Unexpected error %v", i, err) + } + + if actual != tt.expected { + t.Errorf("%d.\nexpected: %s,\nactual: %s", i, tt.expected, actual) + } + } +} + +func TestParseAclItemArray(t *testing.T) { + tests := []struct { + input string + expected []AclItem + errMsg string + }{ + { + "", + []AclItem{}, + "", + }, + { + "one", + []AclItem{"one"}, + "", + }, + { + `"one"`, + []AclItem{"one"}, + "", + }, + { + "one,two,three", + []AclItem{"one", "two", "three"}, + "", + }, + { + `"one","two","three"`, + []AclItem{"one", "two", "three"}, + "", + }, + { + `"one",two,"three"`, + []AclItem{"one", "two", "three"}, + "", + }, + { + `one,two,"three"`, + []AclItem{"one", "two", "three"}, + "", + }, + { + `"one","two",three`, + []AclItem{"one", "two", "three"}, + "", + }, + { + `"one","t w o",three`, + []AclItem{"one", "t w o", "three"}, + "", + }, + { + `"one","t, w o\"\}\\",three`, + []AclItem{"one", `t, w o"}\`, "three"}, + "", + }, + { + `"one","two",three"`, + []AclItem{"one", "two", `three"`}, + "", + }, + { + `"one","two,"three"`, + nil, + "unexpected rune after quoted value", + }, + { + `"one","two","three`, + nil, + "unexpected end of quoted value", + }, + } + + for i, tt := range tests { + actual, err := parseAclItemArray(tt.input) + + if err != nil { + if tt.errMsg == "" { + t.Errorf("%d. Unexpected error %v", i, err) + } else if err.Error() != tt.errMsg { + t.Errorf("%d. Expected error %v did not match actual error %v", i, tt.errMsg, err.Error()) + } + } else if tt.errMsg != "" { + t.Errorf("%d. Expected error not returned: \"%v\"", i, tt.errMsg) + } + + if !reflect.DeepEqual(actual, tt.expected) { + t.Errorf("%d. Expected %v did not match actual %v", i, tt.expected, actual) + } + } +} diff --git a/values.go b/values.go index 6cb6e429..8a7a49cb 100644 --- a/values.go +++ b/values.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "encoding/json" "fmt" + "io" "math" "net" "reflect" @@ -45,6 +46,7 @@ const ( Float4ArrayOid = 1021 Float8ArrayOid = 1022 AclItemOid = 1033 + AclItemArrayOid = 1034 InetArrayOid = 1041 VarcharOid = 1043 DateOid = 1082 @@ -77,6 +79,7 @@ var DefaultTypeFormats map[string]int16 func init() { DefaultTypeFormats = map[string]int16{ + "_aclitem": TextFormatCode, // Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin) "_bool": BinaryFormatCode, "_bytea": BinaryFormatCode, "_cidr": BinaryFormatCode, @@ -981,6 +984,8 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { return Encode(wbuf, oid, v) case string: return encodeString(wbuf, oid, arg) + case []AclItem: + return encodeAclItemSlice(wbuf, oid, arg) case []byte: return encodeByteSlice(wbuf, oid, arg) case [][]byte: @@ -1224,6 +1229,8 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeFloat4(vr) case *float64: *v = decodeFloat8(vr) + case *[]AclItem: + *v = decodeAclItemArray(vr) case *[]bool: *v = decodeBoolArray(vr) case *[]int16: @@ -2993,6 +3000,210 @@ func decodeTextArray(vr *ValueReader) []string { return a } +// escapeAclItem escapes an AclItem before it is added to +// its aclitem[] string representation. The PostgreSQL aclitem +// datatype itself can need escapes because it follows the +// formatting rules of SQL identifiers. Think of this function +// as escaping the escapes, so that PostgreSQL's array parser +// will do the right thing. +func escapeAclItem(acl string) (string, error) { + var escapedAclItem bytes.Buffer + reader := strings.NewReader(acl) + for { + rn, _, err := reader.ReadRune() + if err != nil { + if err == io.EOF { + // Here, EOF is an expected end state, not an error. + return escapedAclItem.String(), nil + } + // This error was not expected + return "", err + } + if needsEscape(rn) { + escapedAclItem.WriteRune('\\') + } + escapedAclItem.WriteRune(rn) + } +} + +// needsEscape determines whether or not a rune needs escaping +// before being placed in the textual representation of an +// aclitem[] array. +func needsEscape(rn rune) bool { + return rn == '\\' || rn == ',' || rn == '"' || rn == '}' +} + +// encodeAclItemSlice encodes a slice of AclItems in +// their textual represention for PostgreSQL. +func encodeAclItemSlice(w *WriteBuf, oid Oid, aclitems []AclItem) error { + strs := make([]string, len(aclitems)) + var escapedAclItem string + var err error + for i := range strs { + escapedAclItem, err = escapeAclItem(string(aclitems[i])) + if err != nil { + return err + } + strs[i] = string(escapedAclItem) + } + + var buf bytes.Buffer + buf.WriteRune('{') + buf.WriteString(strings.Join(strs, ",")) + buf.WriteRune('}') + str := buf.String() + w.WriteInt32(int32(len(str))) + w.WriteBytes([]byte(str)) + return nil +} + +// parseAclItemArray parses the textual representation +// of the aclitem[] type. The textual representation is chosen because +// Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin). +// See https://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO +// for formatting notes. +func parseAclItemArray(arr string) ([]AclItem, error) { + reader := strings.NewReader(arr) + // Difficult to guess a performant initial capacity for a slice of + // aclitems, but let's go with 5. + aclItems := make([]AclItem, 0, 5) + // A single value + aclItem := AclItem("") + for { + // Grab the first/next/last rune to see if we are dealing with a + // quoted value, an unquoted value, or the end of the string. + rn, _, err := reader.ReadRune() + if err != nil { + if err == io.EOF { + // Here, EOF is an expected end state, not an error. + return aclItems, nil + } + // This error was not expected + return nil, err + } + + if rn == '"' { + // Discard the opening quote of the quoted value. + aclItem, err = parseQuotedAclItem(reader) + } else { + // We have just read the first rune of an unquoted (bare) value; + // put it back so that ParseBareValue can read it. + err := reader.UnreadRune() + if err != nil { + return nil, err + } + aclItem, err = parseBareAclItem(reader) + } + + if err != nil { + if err == io.EOF { + // Here, EOF is an expected end state, not an error.. + aclItems = append(aclItems, aclItem) + return aclItems, nil + } + // This error was not expected. + return nil, err + } + aclItems = append(aclItems, aclItem) + } +} + +// parseBareAclItem parses a bare (unquoted) aclitem from reader +func parseBareAclItem(reader *strings.Reader) (AclItem, error) { + var aclItem bytes.Buffer + for { + rn, _, err := reader.ReadRune() + if err != nil { + // Return the read value in case the error is a harmless io.EOF. + // (io.EOF marks the end of a bare aclitem at the end of a string) + return AclItem(aclItem.String()), err + } + if rn == ',' { + // A comma marks the end of a bare aclitem. + return AclItem(aclItem.String()), nil + } else { + aclItem.WriteRune(rn) + } + } +} + +// parseQuotedAclItem parses an aclitem which is in double quotes from reader +func parseQuotedAclItem(reader *strings.Reader) (AclItem, error) { + var aclItem bytes.Buffer + for { + rn, escaped, err := readPossiblyEscapedRune(reader) + if err != nil { + if err == io.EOF { + // Even when it is the last value, the final rune of + // a quoted aclitem should be the final closing quote, not io.EOF. + return AclItem(""), fmt.Errorf("unexpected end of quoted value") + } + // Return the read aclitem in case the error is a harmless io.EOF, + // which will be determined by the caller. + return AclItem(aclItem.String()), err + } + if !escaped && rn == '"' { + // An unescaped double quote marks the end of a quoted value. + // The next rune should either be a comma or the end of the string. + rn, _, err := reader.ReadRune() + if err != nil { + // Return the read value in case the error is a harmless io.EOF, + // which will be determined by the caller. + return AclItem(aclItem.String()), err + } + if rn != ',' { + return AclItem(""), fmt.Errorf("unexpected rune after quoted value") + } + return AclItem(aclItem.String()), nil + } + aclItem.WriteRune(rn) + } +} + +// Returns the next rune from r, unless it is a backslash; +// in that case, it returns the rune after the backslash. The second +// return value tells us whether or not the rune was +// preceeded by a backslash (escaped). +func readPossiblyEscapedRune(reader *strings.Reader) (rune, bool, error) { + rn, _, err := reader.ReadRune() + if err != nil { + return 0, false, err + } + if rn == '\\' { + // Discard the backslash and read the next rune. + rn, _, err = reader.ReadRune() + if err != nil { + return 0, false, err + } + return rn, true, nil + } + return rn, false, nil +} + +func decodeAclItemArray(vr *ValueReader) []AclItem { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into []AclItem")) + return nil + } + + str := vr.ReadString(vr.Len()) + + // Short-circuit empty array. + if str == "{}" { + return []AclItem{} + } + + // Remove the '{' at the front and the '}' at the end, + // so that parseAclItemArray doesn't have to deal with them. + str = str[1 : len(str)-1] + aclItems, err := parseAclItemArray(str) + if err != nil { + vr.Fatal(ProtocolError(err.Error())) + return nil + } + return aclItems +} + func encodeStringSlice(w *WriteBuf, oid Oid, slice []string) error { var elOid Oid switch oid { diff --git a/values_test.go b/values_test.go index 8b85ceef..bbb22f24 100644 --- a/values_test.go +++ b/values_test.go @@ -643,6 +643,52 @@ func TestNullX(t *testing.T) { } } +func assertAclItemSlicesEqual(t *testing.T, query, scan []pgx.AclItem) { + if !reflect.DeepEqual(query, scan) { + t.Errorf("failed to encode aclitem[]\n EXPECTED: %d %v\n ACTUAL: %d %v", len(query), query, len(scan), scan) + } +} + +func TestAclArrayDecoding(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + sql := "select $1::aclitem[]" + var scan []pgx.AclItem + + tests := []struct { + query []pgx.AclItem + }{ + { + []pgx.AclItem{}, + }, + { + []pgx.AclItem{"=r/postgres"}, + }, + { + []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres"}, + }, + { + []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres", `postgres=arwdDxt/" tricky, ' } "" \ test user "`}, + }, + } + for i, tt := range tests { + err := conn.QueryRow(sql, tt.query).Scan(&scan) + if err != nil { + // t.Errorf(`%d. error reading array: %v`, i, err) + t.Errorf(`%d. error reading array: %v query: %s`, i, err, tt.query) + if pgerr, ok := err.(pgx.PgError); ok { + t.Errorf(`%d. error reading array (detail): %s`, i, pgerr.Detail) + } + continue + } + assertAclItemSlicesEqual(t, tt.query, scan) + ensureConnValid(t, conn) + } +} + func TestArrayDecoding(t *testing.T) { t.Parallel()