From 8df9964ce8776e5cb96dfa75c4ad657325dafbba Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 15 Jul 2013 17:57:43 -0500 Subject: [PATCH] MessageWriter needs to be public for custom value transcoders --- connection.go | 46 +++++++++++++++++++-------------------- message_writer.go | 42 +++++++++++++++++++++--------------- value_transcoder.go | 52 ++++++++++++++++++++++----------------------- 3 files changed, 74 insertions(+), 66 deletions(-) diff --git a/connection.go b/connection.go index 8a6020a6..2d5c6097 100644 --- a/connection.go +++ b/connection.go @@ -267,11 +267,11 @@ func (c *Connection) Prepare(name, sql string) (err error) { // parse buf := c.getBuf() w := newMessageWriter(buf) - w.writeCString(name) - w.writeCString(sql) - w.write(int16(0)) - if w.err != nil { - return w.err + w.WriteCString(name) + w.WriteCString(sql) + w.Write(int16(0)) + if w.Err != nil { + return w.Err } err = c.txMsg('P', buf) if err != nil { @@ -281,10 +281,10 @@ func (c *Connection) Prepare(name, sql string) (err error) { // describe buf = c.getBuf() w = newMessageWriter(buf) - w.writeByte('S') - w.writeCString(name) - if w.err != nil { - return w.err + w.WriteByte('S') + w.WriteCString(name) + if w.Err != nil { + return w.Err } err = c.txMsg('D', buf) @@ -371,18 +371,18 @@ func (c *Connection) sendPreparedQuery(ps *preparedStatement, arguments ...inter // bind buf := c.getBuf() w := newMessageWriter(buf) - w.writeCString("") - w.writeCString(ps.Name) - w.write(int16(len(ps.ParameterOids))) + w.WriteCString("") + w.WriteCString(ps.Name) + w.Write(int16(len(ps.ParameterOids))) for _, oid := range ps.ParameterOids { transcoder := ValueTranscoders[oid] if transcoder == nil { transcoder = defaultTranscoder } - w.write(transcoder.EncodeFormat) + w.Write(transcoder.EncodeFormat) } - w.write(int16(len(arguments))) + w.Write(int16(len(arguments))) for i, oid := range ps.ParameterOids { transcoder := ValueTranscoders[oid] if transcoder == nil { @@ -391,17 +391,17 @@ func (c *Connection) sendPreparedQuery(ps *preparedStatement, arguments ...inter transcoder.EncodeTo(w, arguments[i]) } - w.write(int16(len(ps.FieldDescriptions))) + w.Write(int16(len(ps.FieldDescriptions))) for _, fd := range ps.FieldDescriptions { transcoder := ValueTranscoders[fd.DataType] if transcoder != nil && transcoder.DecodeBinary != nil { - w.write(int16(1)) + w.Write(int16(1)) } else { - w.write(int16(0)) + w.Write(int16(0)) } } - if w.err != nil { - return w.err + if w.Err != nil { + return w.Err } err = c.txMsg('B', buf) @@ -412,11 +412,11 @@ func (c *Connection) sendPreparedQuery(ps *preparedStatement, arguments ...inter // execute buf = c.getBuf() w = newMessageWriter(buf) - w.writeCString("") - w.write(int32(0)) + w.WriteCString("") + w.Write(int32(0)) - if w.err != nil { - return w.err + if w.Err != nil { + return w.Err } err = c.txMsg('E', buf) diff --git a/message_writer.go b/message_writer.go index 5d7007f4..20943537 100644 --- a/message_writer.go +++ b/message_writer.go @@ -5,46 +5,54 @@ import ( "encoding/binary" ) -type messageWriter struct { +// MessageWriter is a helper for producing messages to send to PostgreSQL. +// To avoid verbose error handling it internally records errors and no-ops +// any calls that occur after an error. At the end of a sequence of writes +// the Err field should be checked to see if any errors occurred. +type MessageWriter struct { buf *bytes.Buffer - err error + Err error } -func newMessageWriter(buf *bytes.Buffer) *messageWriter { - return &messageWriter{buf: buf} +func newMessageWriter(buf *bytes.Buffer) *MessageWriter { + return &MessageWriter{buf: buf} } -func (w *messageWriter) writeCString(s string) { - if w.err != nil { +// WriteCString writes a null-terminated string. +func (w *MessageWriter) WriteCString(s string) { + if w.Err != nil { return } - if _, w.err = w.buf.WriteString(s); w.err != nil { + if _, w.Err = w.buf.WriteString(s); w.Err != nil { return } - w.err = w.buf.WriteByte(0) + w.Err = w.buf.WriteByte(0) } -func (w *messageWriter) writeString(s string) { - if w.err != nil { +// WriteString writes a string without a null terminator. +func (w *MessageWriter) WriteString(s string) { + if w.Err != nil { return } - if _, w.err = w.buf.WriteString(s); w.err != nil { + if _, w.Err = w.buf.WriteString(s); w.Err != nil { return } } -func (w *messageWriter) writeByte(b byte) { - if w.err != nil { +func (w *MessageWriter) WriteByte(b byte) { + if w.Err != nil { return } - w.err = w.buf.WriteByte(b) + w.Err = w.buf.WriteByte(b) } -func (w *messageWriter) write(data interface{}) { - if w.err != nil { +// Write writes data in the network byte order. data can be an integer type, +// float type, or byte slice. +func (w *MessageWriter) Write(data interface{}) { + if w.Err != nil { return } - w.err = binary.Write(w.buf, binary.BigEndian, data) + w.Err = binary.Write(w.buf, binary.BigEndian, data) } diff --git a/value_transcoder.go b/value_transcoder.go index 9614ea58..ad7602be 100644 --- a/value_transcoder.go +++ b/value_transcoder.go @@ -15,7 +15,7 @@ type ValueTranscoder struct { // DecodeBinary decodes values returned from the server in binary format DecodeBinary func(*MessageReader, int32) interface{} // EncodeTo encodes values to send to the server - EncodeTo func(*messageWriter, interface{}) + EncodeTo func(*MessageWriter, interface{}) // EncodeFormat is the format values are encoded for transmission. // 0 = text // 1 = binary @@ -112,13 +112,13 @@ func decodeBoolFromBinary(mr *MessageReader, size int32) interface{} { return b != 0 } -func encodeBool(w *messageWriter, value interface{}) { +func encodeBool(w *MessageWriter, value interface{}) { v := value.(bool) - w.write(int32(1)) + w.Write(int32(1)) if v { - w.writeByte(1) + w.WriteByte(1) } else { - w.writeByte(0) + w.WriteByte(0) } } @@ -138,10 +138,10 @@ func decodeInt8FromBinary(mr *MessageReader, size int32) interface{} { return mr.ReadInt64() } -func encodeInt8(w *messageWriter, value interface{}) { +func encodeInt8(w *MessageWriter, value interface{}) { v := value.(int64) - w.write(int32(8)) - w.write(v) + w.Write(int32(8)) + w.Write(v) } func decodeInt2FromText(mr *MessageReader, size int32) interface{} { @@ -160,10 +160,10 @@ func decodeInt2FromBinary(mr *MessageReader, size int32) interface{} { return mr.ReadInt16() } -func encodeInt2(w *messageWriter, value interface{}) { +func encodeInt2(w *MessageWriter, value interface{}) { v := value.(int16) - w.write(int32(2)) - w.write(v) + w.Write(int32(2)) + w.Write(v) } func decodeInt4FromText(mr *MessageReader, size int32) interface{} { @@ -182,10 +182,10 @@ func decodeInt4FromBinary(mr *MessageReader, size int32) interface{} { return mr.ReadInt32() } -func encodeInt4(w *messageWriter, value interface{}) { +func encodeInt4(w *MessageWriter, value interface{}) { v := value.(int32) - w.write(int32(4)) - w.write(v) + w.Write(int32(4)) + w.Write(v) } func decodeFloat4FromText(mr *MessageReader, size int32) interface{} { @@ -207,10 +207,10 @@ func decodeFloat4FromBinary(mr *MessageReader, size int32) interface{} { return *(*float32)(p) } -func encodeFloat4(w *messageWriter, value interface{}) { +func encodeFloat4(w *MessageWriter, value interface{}) { v := value.(float32) - w.write(int32(4)) - w.write(v) + w.Write(int32(4)) + w.Write(v) } func decodeFloat8FromText(mr *MessageReader, size int32) interface{} { @@ -232,20 +232,20 @@ func decodeFloat8FromBinary(mr *MessageReader, size int32) interface{} { return *(*float64)(p) } -func encodeFloat8(w *messageWriter, value interface{}) { +func encodeFloat8(w *MessageWriter, value interface{}) { v := value.(float64) - w.write(int32(8)) - w.write(v) + w.Write(int32(8)) + w.Write(v) } func decodeTextFromText(mr *MessageReader, size int32) interface{} { return mr.ReadByteString(size) } -func encodeText(w *messageWriter, value interface{}) { +func encodeText(w *MessageWriter, value interface{}) { s := value.(string) - w.write(int32(len(s))) - w.writeString(s) + w.Write(int32(len(s))) + w.WriteString(s) } func decodeByteaFromText(mr *MessageReader, size int32) interface{} { @@ -257,8 +257,8 @@ func decodeByteaFromText(mr *MessageReader, size int32) interface{} { return b } -func encodeBytea(w *messageWriter, value interface{}) { +func encodeBytea(w *MessageWriter, value interface{}) { b := value.([]byte) - w.write(int32(len(b))) - w.write(b) + w.Write(int32(len(b))) + w.Write(b) }