2
0

MessageWriter needs to be public for custom value transcoders

This commit is contained in:
Jack Christensen
2013-07-15 17:57:43 -05:00
parent 1af652ce07
commit 8df9964ce8
3 changed files with 74 additions and 66 deletions
+23 -23
View File
@@ -267,11 +267,11 @@ func (c *Connection) Prepare(name, sql string) (err error) {
// parse // parse
buf := c.getBuf() buf := c.getBuf()
w := newMessageWriter(buf) w := newMessageWriter(buf)
w.writeCString(name) w.WriteCString(name)
w.writeCString(sql) w.WriteCString(sql)
w.write(int16(0)) w.Write(int16(0))
if w.err != nil { if w.Err != nil {
return w.err return w.Err
} }
err = c.txMsg('P', buf) err = c.txMsg('P', buf)
if err != nil { if err != nil {
@@ -281,10 +281,10 @@ func (c *Connection) Prepare(name, sql string) (err error) {
// describe // describe
buf = c.getBuf() buf = c.getBuf()
w = newMessageWriter(buf) w = newMessageWriter(buf)
w.writeByte('S') w.WriteByte('S')
w.writeCString(name) w.WriteCString(name)
if w.err != nil { if w.Err != nil {
return w.err return w.Err
} }
err = c.txMsg('D', buf) err = c.txMsg('D', buf)
@@ -371,18 +371,18 @@ func (c *Connection) sendPreparedQuery(ps *preparedStatement, arguments ...inter
// bind // bind
buf := c.getBuf() buf := c.getBuf()
w := newMessageWriter(buf) w := newMessageWriter(buf)
w.writeCString("") w.WriteCString("")
w.writeCString(ps.Name) w.WriteCString(ps.Name)
w.write(int16(len(ps.ParameterOids))) w.Write(int16(len(ps.ParameterOids)))
for _, oid := range ps.ParameterOids { for _, oid := range ps.ParameterOids {
transcoder := ValueTranscoders[oid] transcoder := ValueTranscoders[oid]
if transcoder == nil { if transcoder == nil {
transcoder = defaultTranscoder transcoder = defaultTranscoder
} }
w.write(transcoder.EncodeFormat) w.Write(transcoder.EncodeFormat)
} }
w.write(int16(len(arguments))) w.Write(int16(len(arguments)))
for i, oid := range ps.ParameterOids { for i, oid := range ps.ParameterOids {
transcoder := ValueTranscoders[oid] transcoder := ValueTranscoders[oid]
if transcoder == nil { if transcoder == nil {
@@ -391,17 +391,17 @@ func (c *Connection) sendPreparedQuery(ps *preparedStatement, arguments ...inter
transcoder.EncodeTo(w, arguments[i]) transcoder.EncodeTo(w, arguments[i])
} }
w.write(int16(len(ps.FieldDescriptions))) w.Write(int16(len(ps.FieldDescriptions)))
for _, fd := range ps.FieldDescriptions { for _, fd := range ps.FieldDescriptions {
transcoder := ValueTranscoders[fd.DataType] transcoder := ValueTranscoders[fd.DataType]
if transcoder != nil && transcoder.DecodeBinary != nil { if transcoder != nil && transcoder.DecodeBinary != nil {
w.write(int16(1)) w.Write(int16(1))
} else { } else {
w.write(int16(0)) w.Write(int16(0))
} }
} }
if w.err != nil { if w.Err != nil {
return w.err return w.Err
} }
err = c.txMsg('B', buf) err = c.txMsg('B', buf)
@@ -412,11 +412,11 @@ func (c *Connection) sendPreparedQuery(ps *preparedStatement, arguments ...inter
// execute // execute
buf = c.getBuf() buf = c.getBuf()
w = newMessageWriter(buf) w = newMessageWriter(buf)
w.writeCString("") w.WriteCString("")
w.write(int32(0)) w.Write(int32(0))
if w.err != nil { if w.Err != nil {
return w.err return w.Err
} }
err = c.txMsg('E', buf) err = c.txMsg('E', buf)
+25 -17
View File
@@ -5,46 +5,54 @@ import (
"encoding/binary" "encoding/binary"
) )
type messageWriter struct { // MessageWriter is a helper for producing messages to send to PostgreSQL.
// To avoid verbose error handling it internally records errors and no-ops
// any calls that occur after an error. At the end of a sequence of writes
// the Err field should be checked to see if any errors occurred.
type MessageWriter struct {
buf *bytes.Buffer buf *bytes.Buffer
err error Err error
} }
func newMessageWriter(buf *bytes.Buffer) *messageWriter { func newMessageWriter(buf *bytes.Buffer) *MessageWriter {
return &messageWriter{buf: buf} return &MessageWriter{buf: buf}
} }
func (w *messageWriter) writeCString(s string) { // WriteCString writes a null-terminated string.
if w.err != nil { func (w *MessageWriter) WriteCString(s string) {
if w.Err != nil {
return return
} }
if _, w.err = w.buf.WriteString(s); w.err != nil { if _, w.Err = w.buf.WriteString(s); w.Err != nil {
return return
} }
w.err = w.buf.WriteByte(0) w.Err = w.buf.WriteByte(0)
} }
func (w *messageWriter) writeString(s string) { // WriteString writes a string without a null terminator.
if w.err != nil { func (w *MessageWriter) WriteString(s string) {
if w.Err != nil {
return return
} }
if _, w.err = w.buf.WriteString(s); w.err != nil { if _, w.Err = w.buf.WriteString(s); w.Err != nil {
return return
} }
} }
func (w *messageWriter) writeByte(b byte) { func (w *MessageWriter) WriteByte(b byte) {
if w.err != nil { if w.Err != nil {
return return
} }
w.err = w.buf.WriteByte(b) w.Err = w.buf.WriteByte(b)
} }
func (w *messageWriter) write(data interface{}) { // Write writes data in the network byte order. data can be an integer type,
if w.err != nil { // float type, or byte slice.
func (w *MessageWriter) Write(data interface{}) {
if w.Err != nil {
return return
} }
w.err = binary.Write(w.buf, binary.BigEndian, data) w.Err = binary.Write(w.buf, binary.BigEndian, data)
} }
+26 -26
View File
@@ -15,7 +15,7 @@ type ValueTranscoder struct {
// DecodeBinary decodes values returned from the server in binary format // DecodeBinary decodes values returned from the server in binary format
DecodeBinary func(*MessageReader, int32) interface{} DecodeBinary func(*MessageReader, int32) interface{}
// EncodeTo encodes values to send to the server // EncodeTo encodes values to send to the server
EncodeTo func(*messageWriter, interface{}) EncodeTo func(*MessageWriter, interface{})
// EncodeFormat is the format values are encoded for transmission. // EncodeFormat is the format values are encoded for transmission.
// 0 = text // 0 = text
// 1 = binary // 1 = binary
@@ -112,13 +112,13 @@ func decodeBoolFromBinary(mr *MessageReader, size int32) interface{} {
return b != 0 return b != 0
} }
func encodeBool(w *messageWriter, value interface{}) { func encodeBool(w *MessageWriter, value interface{}) {
v := value.(bool) v := value.(bool)
w.write(int32(1)) w.Write(int32(1))
if v { if v {
w.writeByte(1) w.WriteByte(1)
} else { } else {
w.writeByte(0) w.WriteByte(0)
} }
} }
@@ -138,10 +138,10 @@ func decodeInt8FromBinary(mr *MessageReader, size int32) interface{} {
return mr.ReadInt64() return mr.ReadInt64()
} }
func encodeInt8(w *messageWriter, value interface{}) { func encodeInt8(w *MessageWriter, value interface{}) {
v := value.(int64) v := value.(int64)
w.write(int32(8)) w.Write(int32(8))
w.write(v) w.Write(v)
} }
func decodeInt2FromText(mr *MessageReader, size int32) interface{} { func decodeInt2FromText(mr *MessageReader, size int32) interface{} {
@@ -160,10 +160,10 @@ func decodeInt2FromBinary(mr *MessageReader, size int32) interface{} {
return mr.ReadInt16() return mr.ReadInt16()
} }
func encodeInt2(w *messageWriter, value interface{}) { func encodeInt2(w *MessageWriter, value interface{}) {
v := value.(int16) v := value.(int16)
w.write(int32(2)) w.Write(int32(2))
w.write(v) w.Write(v)
} }
func decodeInt4FromText(mr *MessageReader, size int32) interface{} { func decodeInt4FromText(mr *MessageReader, size int32) interface{} {
@@ -182,10 +182,10 @@ func decodeInt4FromBinary(mr *MessageReader, size int32) interface{} {
return mr.ReadInt32() return mr.ReadInt32()
} }
func encodeInt4(w *messageWriter, value interface{}) { func encodeInt4(w *MessageWriter, value interface{}) {
v := value.(int32) v := value.(int32)
w.write(int32(4)) w.Write(int32(4))
w.write(v) w.Write(v)
} }
func decodeFloat4FromText(mr *MessageReader, size int32) interface{} { func decodeFloat4FromText(mr *MessageReader, size int32) interface{} {
@@ -207,10 +207,10 @@ func decodeFloat4FromBinary(mr *MessageReader, size int32) interface{} {
return *(*float32)(p) return *(*float32)(p)
} }
func encodeFloat4(w *messageWriter, value interface{}) { func encodeFloat4(w *MessageWriter, value interface{}) {
v := value.(float32) v := value.(float32)
w.write(int32(4)) w.Write(int32(4))
w.write(v) w.Write(v)
} }
func decodeFloat8FromText(mr *MessageReader, size int32) interface{} { func decodeFloat8FromText(mr *MessageReader, size int32) interface{} {
@@ -232,20 +232,20 @@ func decodeFloat8FromBinary(mr *MessageReader, size int32) interface{} {
return *(*float64)(p) return *(*float64)(p)
} }
func encodeFloat8(w *messageWriter, value interface{}) { func encodeFloat8(w *MessageWriter, value interface{}) {
v := value.(float64) v := value.(float64)
w.write(int32(8)) w.Write(int32(8))
w.write(v) w.Write(v)
} }
func decodeTextFromText(mr *MessageReader, size int32) interface{} { func decodeTextFromText(mr *MessageReader, size int32) interface{} {
return mr.ReadByteString(size) return mr.ReadByteString(size)
} }
func encodeText(w *messageWriter, value interface{}) { func encodeText(w *MessageWriter, value interface{}) {
s := value.(string) s := value.(string)
w.write(int32(len(s))) w.Write(int32(len(s)))
w.writeString(s) w.WriteString(s)
} }
func decodeByteaFromText(mr *MessageReader, size int32) interface{} { func decodeByteaFromText(mr *MessageReader, size int32) interface{} {
@@ -257,8 +257,8 @@ func decodeByteaFromText(mr *MessageReader, size int32) interface{} {
return b return b
} }
func encodeBytea(w *messageWriter, value interface{}) { func encodeBytea(w *MessageWriter, value interface{}) {
b := value.([]byte) b := value.([]byte)
w.write(int32(len(b))) w.Write(int32(len(b)))
w.write(b) w.Write(b)
} }