2
0

Adding hstore support. map[string]string will encode to hstores and throw errors on hstores with NULL values, and there is now a NullHstore type that is basically map[string]NullString and will both accept and decode NULL values properly

This commit is contained in:
Andy Walker
2014-09-17 16:17:23 -04:00
committed by Jack Christensen
parent 0441bcd8e4
commit 821605a8dd
4 changed files with 350 additions and 1 deletions
+131
View File
@@ -1,10 +1,12 @@
package pgx
import (
"bytes"
"encoding/hex"
"fmt"
"math"
"strconv"
"strings"
"time"
"unsafe"
)
@@ -419,6 +421,81 @@ func (n NullTime) Encode(w *WriteBuf, oid Oid) error {
return encodeTimestampTz(w, n.Time)
}
// NullHstore represents an hstore that can be null or have null values
// associated with its keys. NullHstore implements the Scanner and TextEncoder
// interfaces so it may be used both as an argument to Query[Row] and a
// destination for Scan for prepared and unprepared queries.
//
// If Valid is false, then the value of the entire hstore column is NULL
// If any of the NullString values in Store has Valid set to false, the key
// appears in the hstore column, but its value is explicitly set to NULL
type NullHstore struct {
Store map[string]NullString
Valid bool
}
func (h *NullHstore) Scan(vr *ValueReader) error {
//oid for hstore not standardized, so we check its type name
if vr.Type().DataTypeName != "hstore" {
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into map[string]string", vr.Type().DataTypeName)))
return nil
}
if vr.Len() == -1 {
h.Valid = false
return nil
}
switch vr.Type().FormatCode {
case TextFormatCode:
store, err := parseHstoreToNullHstore(vr.ReadString(vr.Len()))
if err != nil {
vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err)))
return nil
}
h.Valid = true
h.Store = store
return nil
case BinaryFormatCode:
vr.Fatal(ProtocolError("Can't decode binary hstore"))
return nil
default:
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return nil
}
}
func (h NullHstore) FormatCode() int16 { return TextFormatCode }
func (h NullHstore) Encode(w *WriteBuf, oid Oid) error {
var buf bytes.Buffer
if !h.Valid {
w.WriteInt32(-1)
return nil
}
i := 0
for k, v := range h.Store {
i++
ks := strings.Replace(k, `\`, `\\`, -1)
ks = strings.Replace(ks, `"`, `\"`, -1)
if v.Valid {
vs := strings.Replace(v.String, `\`, `\\`, -1)
vs = strings.Replace(vs, `"`, `\"`, -1)
buf.WriteString(fmt.Sprintf(`"%s"=>"%s"`, ks, vs))
} else {
buf.WriteString(fmt.Sprintf(`"%s"=>NULL`, ks))
}
if i < len(h.Store) {
buf.WriteString(", ")
}
}
w.WriteInt32(int32(buf.Len()))
w.WriteBytes(buf.Bytes())
return nil
}
func decodeBool(vr *ValueReader) bool {
if vr.Len() == -1 {
vr.Fatal(ProtocolError("Cannot decode null into bool"))
@@ -1045,6 +1122,60 @@ func encodeTimestamp(w *WriteBuf, value interface{}) error {
return encodeText(w, s)
}
func decodeHstore(vr *ValueReader) map[string]string {
//oid for hstore not standardized, so we check its type name
if vr.Type().DataTypeName != "hstore" {
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into map[string]string", vr.Type().DataTypeName)))
return nil
}
if vr.Len() == -1 {
vr.Fatal(ProtocolError("Cannot decode null into map[string]string"))
return nil
}
switch vr.Type().FormatCode {
case TextFormatCode:
m, err := parseHstoreToMap(vr.ReadString(vr.Len()))
if err != nil {
vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err)))
return nil
}
return m
case BinaryFormatCode:
vr.Fatal(ProtocolError("Can't decode binary hstore"))
return nil
default:
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return nil
}
}
func encodeHstore(w *WriteBuf, value interface{}) error {
var buf bytes.Buffer
h, ok := value.(map[string]string)
if !ok {
return fmt.Errorf("Expected map[string]string, received %T", value)
}
i := 0
for k, v := range h {
i++
ks := strings.Replace(k, `\`, `\\`, -1)
ks = strings.Replace(ks, `"`, `\"`, -1)
vs := strings.Replace(v, `\`, `\\`, -1)
vs = strings.Replace(vs, `"`, `\"`, -1)
buf.WriteString(fmt.Sprintf(`"%s"=>"%s"`, ks, vs))
if i < len(h) {
buf.WriteString(", ")
}
}
w.WriteInt32(int32(buf.Len()))
w.WriteBytes(buf.Bytes())
return nil
}
func decode1dArrayHeader(vr *ValueReader) (length int32, err error) {
numDims := vr.ReadInt32()
if numDims == 0 {