From 821605a8dd1013b2c0a824a4f39624baed1bba4a Mon Sep 17 00:00:00 2001 From: Andy Walker Date: Wed, 17 Sep 2014 16:17:23 -0400 Subject: [PATCH] Adding hstore support. map[string]string will encode to hstores and throw errors on hstores with NULL values, and there is now a NullHstore type that is basically map[string]NullString and will both accept and decode NULL values properly --- conn.go | 2 + hstore.go | 215 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ query.go | 3 +- values.go | 131 +++++++++++++++++++++++++++++++++ 4 files changed, 350 insertions(+), 1 deletion(-) create mode 100644 hstore.go diff --git a/conn.go b/conn.go index 53f84c75..c1388c7a 100644 --- a/conn.go +++ b/conn.go @@ -519,6 +519,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} err = arg.Encode(wbuf, oid) case string: err = encodeText(wbuf, arguments[i]) + case map[string]string: + err = encodeHstore(wbuf, arguments[i]) default: switch oid { case BoolOid: diff --git a/hstore.go b/hstore.go new file mode 100644 index 00000000..a7b362df --- /dev/null +++ b/hstore.go @@ -0,0 +1,215 @@ +package pgx + +import ( + "bytes" + "errors" + "fmt" + "unicode" + "unicode/utf8" +) + +const ( + hsPre = iota + hsKey + hsSep + hsVal + hsNul + hsNext + hsEnd +) + +type hstoreParser struct { + str string + pos int +} + +func newHSP(in string) *hstoreParser { + return &hstoreParser{ + pos: 0, + str: in, + } +} + +func (p *hstoreParser) Consume() (r rune, end bool) { + if p.pos >= len(p.str) { + end = true + return + } + r, w := utf8.DecodeRuneInString(p.str[p.pos:]) + p.pos += w + return +} + +func (p *hstoreParser) Peek() (r rune, end bool) { + if p.pos >= len(p.str) { + end = true + return + } + r, _ = utf8.DecodeRuneInString(p.str[p.pos:]) + return +} + +func parseHstoreToMap(s string) (m map[string]string, err error) { + keys, values, err := ParseHstore(s) + if err != nil { + return + } + m = make(map[string]string, len(keys)) + for i, key := range keys { + if !values[i].Valid { + err = fmt.Errorf("key '%s' has NULL value", key) + m = nil + return + } + m[key] = values[i].String + } + return +} + +func parseHstoreToNullHstore(s string) (store map[string]NullString, err error) { + keys, values, err := ParseHstore(s) + if err != nil { + return + } + + store = make(map[string]NullString, len(keys)) + + for i, key := range keys { + store[key] = values[i] + } + return +} + +func ParseHstore(s string) (k []string, v []NullString, err error) { + + buf := bytes.Buffer{} + keys := []string{} + values := []NullString{} + p := newHSP(s) + + r, end := p.Consume() + state := hsPre + + for !end { + switch state { + case hsPre: + if r == '"' { + state = hsKey + } else { + err = errors.New("String does not begin with \"") + } + case hsKey: + switch r { + case '"': //End of the key + if buf.Len() == 0 { + err = errors.New("Empty Key is invalid") + } else { + keys = append(keys, buf.String()) + buf = bytes.Buffer{} + state = hsSep + } + case '\\': //Potential escaped character + n, end := p.Consume() + switch { + case end: + err = errors.New("Found EOS in key, expecting character or \"") + case n == '"', n == '\\': + buf.WriteRune(n) + default: + buf.WriteRune(r) + buf.WriteRune(n) + } + default: //Any other character + buf.WriteRune(r) + } + case hsSep: + if r == '=' { + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after '=', expecting '>'") + case r == '>': + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'") + case r == '"': + state = hsVal + case r == 'N': + state = hsNul + default: + err = fmt.Errorf("Invalid character '%s' after '=>', expecting '\"' or 'NULL'") + } + default: + err = fmt.Errorf("Invalid character after '=', expecting '>'") + } + } else { + err = fmt.Errorf("Invalid character '%s' after value, expecting '='", r) + } + case hsVal: + switch r { + case '"': //End of the value + values = append(values, NullString{String: buf.String(), Valid: true}) + buf = bytes.Buffer{} + state = hsNext + case '\\': //Potential escaped character + n, end := p.Consume() + switch { + case end: + err = errors.New("Found EOS in key, expecting character or \"") + case n == '"', n == '\\': + buf.WriteRune(n) + default: + buf.WriteRune(r) + buf.WriteRune(n) + } + default: //Any other character + buf.WriteRune(r) + } + case hsNul: + nulBuf := make([]rune, 3) + nulBuf[0] = r + for i := 1; i < 3; i++ { + r, end = p.Consume() + if end { + err = errors.New("Found EOS in NULL value") + return + } + nulBuf[i] = r + } + if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' { + values = append(values, NullString{String: "", Valid: false}) + state = hsNext + } else { + err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) + } + case hsNext: + if r == ',' { + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after ',', expcting space") + case (unicode.IsSpace(r)): + r, end = p.Consume() + state = hsKey + default: + err = fmt.Errorf("Invalid character '%s' after ', ', expecting \"", r) + } + } else { + err = fmt.Errorf("Invalid character '%s' after value, expecting ','", r) + } + } + + if err != nil { + return + } + r, end = p.Consume() + } + if state != hsNext { + err = errors.New("Improperly formatted hstore") + return + } + k = keys + v = values + return +} diff --git a/query.go b/query.go index d6ecbdba..81e5561b 100644 --- a/query.go +++ b/query.go @@ -254,7 +254,8 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { default: rows.Fatal(fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType)) } - + case *map[string]string: + *d = decodeHstore(vr) case Scanner: err = d.Scan(vr) if err != nil { diff --git a/values.go b/values.go index c216a085..c22f8557 100644 --- a/values.go +++ b/values.go @@ -1,10 +1,12 @@ package pgx import ( + "bytes" "encoding/hex" "fmt" "math" "strconv" + "strings" "time" "unsafe" ) @@ -419,6 +421,81 @@ func (n NullTime) Encode(w *WriteBuf, oid Oid) error { return encodeTimestampTz(w, n.Time) } +// NullHstore represents an hstore that can be null or have null values +// associated with its keys. NullHstore implements the Scanner and TextEncoder +// interfaces so it may be used both as an argument to Query[Row] and a +// destination for Scan for prepared and unprepared queries. +// +// If Valid is false, then the value of the entire hstore column is NULL +// If any of the NullString values in Store has Valid set to false, the key +// appears in the hstore column, but its value is explicitly set to NULL +type NullHstore struct { + Store map[string]NullString + Valid bool +} + +func (h *NullHstore) Scan(vr *ValueReader) error { + //oid for hstore not standardized, so we check its type name + if vr.Type().DataTypeName != "hstore" { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into map[string]string", vr.Type().DataTypeName))) + return nil + } + + if vr.Len() == -1 { + h.Valid = false + return nil + } + + switch vr.Type().FormatCode { + case TextFormatCode: + store, err := parseHstoreToNullHstore(vr.ReadString(vr.Len())) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err))) + return nil + } + h.Valid = true + h.Store = store + return nil + case BinaryFormatCode: + vr.Fatal(ProtocolError("Can't decode binary hstore")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } +} + +func (h NullHstore) FormatCode() int16 { return TextFormatCode } + +func (h NullHstore) Encode(w *WriteBuf, oid Oid) error { + var buf bytes.Buffer + + if !h.Valid { + w.WriteInt32(-1) + return nil + } + + i := 0 + for k, v := range h.Store { + i++ + ks := strings.Replace(k, `\`, `\\`, -1) + ks = strings.Replace(ks, `"`, `\"`, -1) + if v.Valid { + vs := strings.Replace(v.String, `\`, `\\`, -1) + vs = strings.Replace(vs, `"`, `\"`, -1) + buf.WriteString(fmt.Sprintf(`"%s"=>"%s"`, ks, vs)) + } else { + buf.WriteString(fmt.Sprintf(`"%s"=>NULL`, ks)) + } + if i < len(h.Store) { + buf.WriteString(", ") + } + } + w.WriteInt32(int32(buf.Len())) + w.WriteBytes(buf.Bytes()) + return nil +} + func decodeBool(vr *ValueReader) bool { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into bool")) @@ -1045,6 +1122,60 @@ func encodeTimestamp(w *WriteBuf, value interface{}) error { return encodeText(w, s) } +func decodeHstore(vr *ValueReader) map[string]string { + //oid for hstore not standardized, so we check its type name + if vr.Type().DataTypeName != "hstore" { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into map[string]string", vr.Type().DataTypeName))) + return nil + } + + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into map[string]string")) + return nil + } + + switch vr.Type().FormatCode { + case TextFormatCode: + m, err := parseHstoreToMap(vr.ReadString(vr.Len())) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err))) + return nil + } + return m + case BinaryFormatCode: + vr.Fatal(ProtocolError("Can't decode binary hstore")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } +} + +func encodeHstore(w *WriteBuf, value interface{}) error { + var buf bytes.Buffer + + h, ok := value.(map[string]string) + if !ok { + return fmt.Errorf("Expected map[string]string, received %T", value) + } + + i := 0 + for k, v := range h { + i++ + ks := strings.Replace(k, `\`, `\\`, -1) + ks = strings.Replace(ks, `"`, `\"`, -1) + vs := strings.Replace(v, `\`, `\\`, -1) + vs = strings.Replace(vs, `"`, `\"`, -1) + buf.WriteString(fmt.Sprintf(`"%s"=>"%s"`, ks, vs)) + if i < len(h) { + buf.WriteString(", ") + } + } + w.WriteInt32(int32(buf.Len())) + w.WriteBytes(buf.Bytes()) + return nil +} + func decode1dArrayHeader(vr *ValueReader) (length int32, err error) { numDims := vr.ReadInt32() if numDims == 0 {