diff --git a/conn.go b/conn.go index c1388c7a..53f84c75 100644 --- a/conn.go +++ b/conn.go @@ -519,8 +519,6 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} err = arg.Encode(wbuf, oid) case string: err = encodeText(wbuf, arguments[i]) - case map[string]string: - err = encodeHstore(wbuf, arguments[i]) default: switch oid { case BoolOid: diff --git a/query.go b/query.go index 81e5561b..d1e8b1cd 100644 --- a/query.go +++ b/query.go @@ -254,8 +254,6 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { default: rows.Fatal(fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType)) } - case *map[string]string: - *d = decodeHstore(vr) case Scanner: err = d.Scan(vr) if err != nil { diff --git a/values.go b/values.go index c22f8557..436d57d8 100644 --- a/values.go +++ b/values.go @@ -421,7 +421,66 @@ func (n NullTime) Encode(w *WriteBuf, oid Oid) error { return encodeTimestampTz(w, n.Time) } -// NullHstore represents an hstore that can be null or have null values +//Hstore represents an hstore column. It does not support a null column or null +// key values (use NullHstore for this). Hstore implements the Scanner and +// TextEncoder interfaces so it may be used both as an argument to Query[Row] +// and a destination for Scan for prepared and unprepared queries. +type Hstore map[string]string + +func (h *Hstore) Scan(vr *ValueReader) error { + //oid for hstore not standardized, so we check its type name + if vr.Type().DataTypeName != "hstore" { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into Hstore", vr.Type().DataTypeName))) + return nil + } + + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null column into Hstore")) + return nil + } + + switch vr.Type().FormatCode { + case TextFormatCode: + m, err := parseHstoreToMap(vr.ReadString(vr.Len())) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err))) + return nil + } + hm := Hstore(m) + *h = hm + return nil + case BinaryFormatCode: + vr.Fatal(ProtocolError("Can't decode binary hstore")) + return nil + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } +} + +func (h Hstore) FormatCode() int16 { return TextFormatCode } + +func (h Hstore) Encode(w *WriteBuf, oid Oid) error { + var buf bytes.Buffer + + i := 0 + for k, v := range h { + i++ + ks := strings.Replace(k, `\`, `\\`, -1) + ks = strings.Replace(ks, `"`, `\"`, -1) + vs := strings.Replace(v, `\`, `\\`, -1) + vs = strings.Replace(vs, `"`, `\"`, -1) + buf.WriteString(fmt.Sprintf(`"%s"=>"%s"`, ks, vs)) + if i < len(h) { + buf.WriteString(", ") + } + } + w.WriteInt32(int32(buf.Len())) + w.WriteBytes(buf.Bytes()) + return nil +} + +// NullHstore represents an hstore column that can be null or have null values // associated with its keys. NullHstore implements the Scanner and TextEncoder // interfaces so it may be used both as an argument to Query[Row] and a // destination for Scan for prepared and unprepared queries. @@ -437,7 +496,7 @@ type NullHstore struct { func (h *NullHstore) Scan(vr *ValueReader) error { //oid for hstore not standardized, so we check its type name if vr.Type().DataTypeName != "hstore" { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into map[string]string", vr.Type().DataTypeName))) + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into NullHstore", vr.Type().DataTypeName))) return nil } @@ -1122,60 +1181,6 @@ func encodeTimestamp(w *WriteBuf, value interface{}) error { return encodeText(w, s) } -func decodeHstore(vr *ValueReader) map[string]string { - //oid for hstore not standardized, so we check its type name - if vr.Type().DataTypeName != "hstore" { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into map[string]string", vr.Type().DataTypeName))) - return nil - } - - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into map[string]string")) - return nil - } - - switch vr.Type().FormatCode { - case TextFormatCode: - m, err := parseHstoreToMap(vr.ReadString(vr.Len())) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err))) - return nil - } - return m - case BinaryFormatCode: - vr.Fatal(ProtocolError("Can't decode binary hstore")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } -} - -func encodeHstore(w *WriteBuf, value interface{}) error { - var buf bytes.Buffer - - h, ok := value.(map[string]string) - if !ok { - return fmt.Errorf("Expected map[string]string, received %T", value) - } - - i := 0 - for k, v := range h { - i++ - ks := strings.Replace(k, `\`, `\\`, -1) - ks = strings.Replace(ks, `"`, `\"`, -1) - vs := strings.Replace(v, `\`, `\\`, -1) - vs = strings.Replace(vs, `"`, `\"`, -1) - buf.WriteString(fmt.Sprintf(`"%s"=>"%s"`, ks, vs)) - if i < len(h) { - buf.WriteString(", ") - } - } - w.WriteInt32(int32(buf.Len())) - w.WriteBytes(buf.Bytes()) - return nil -} - func decode1dArrayHeader(vr *ValueReader) (length int32, err error) { numDims := vr.ReadInt32() if numDims == 0 {