From 884252675e9020a9dcee3bc9d73f33adc380a07f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 19 Jun 2014 08:03:14 -0500 Subject: [PATCH] Handle ValueTranscoder EncodeTo returns error on bad type Instead of panicking --- conn.go | 5 +- example_value_transcoder_test.go | 10 ++- value_transcoder.go | 141 ++++++++++++++++++++++++------- value_transcoder_test.go | 21 +++++ 4 files changed, 144 insertions(+), 33 deletions(-) diff --git a/conn.go b/conn.go index 14bd7671..63e32d6d 100644 --- a/conn.go +++ b/conn.go @@ -793,7 +793,10 @@ func (c *Conn) sendPreparedQuery(ps *preparedStatement, arguments ...interface{} if transcoder == nil { transcoder = defaultTranscoder } - transcoder.EncodeTo(w, arguments[i]) + err = transcoder.EncodeTo(w, arguments[i]) + if err != nil { + return err + } } else { w.Write(int32(-1)) } diff --git a/example_value_transcoder_test.go b/example_value_transcoder_test.go index 0e3e2a25..b5af0872 100644 --- a/example_value_transcoder_test.go +++ b/example_value_transcoder_test.go @@ -55,9 +55,15 @@ func decodePointFromText(mr *pgx.MessageReader, size int32) interface{} { return p } -func encodePoint(w *pgx.MessageWriter, value interface{}) { - p := value.(Point) +func encodePoint(w *pgx.MessageWriter, value interface{}) error { + p, ok := value.(Point) + if !ok { + return fmt.Errorf("Expected Point, received %T", value) + } + s := fmt.Sprintf("point(%v,%v)", p.x, p.y) w.Write(int32(len(s))) w.WriteString(s) + + return nil } diff --git a/value_transcoder.go b/value_transcoder.go index ec78e7b6..d83f0318 100644 --- a/value_transcoder.go +++ b/value_transcoder.go @@ -18,7 +18,7 @@ type ValueTranscoder struct { // DecodeBinary decodes values returned from the server in binary format DecodeBinary func(*MessageReader, int32) interface{} // EncodeTo encodes values to send to the server - EncodeTo func(*MessageWriter, interface{}) + EncodeTo func(*MessageWriter, interface{}) error // EncodeFormat is the format values are encoded for transmission. // 0 = text // 1 = binary @@ -159,14 +159,20 @@ func decodeBoolFromBinary(mr *MessageReader, size int32) interface{} { return b != 0 } -func encodeBool(w *MessageWriter, value interface{}) { - v := value.(bool) +func encodeBool(w *MessageWriter, value interface{}) error { + v, ok := value.(bool) + if !ok { + return fmt.Errorf("Expected bool, received %T", value) + } + w.Write(int32(1)) if v { w.WriteByte(1) } else { w.WriteByte(0) } + + return nil } func decodeInt8FromText(mr *MessageReader, size int32) interface{} { @@ -185,10 +191,16 @@ func decodeInt8FromBinary(mr *MessageReader, size int32) interface{} { return mr.ReadInt64() } -func encodeInt8(w *MessageWriter, value interface{}) { - v := value.(int64) +func encodeInt8(w *MessageWriter, value interface{}) error { + v, ok := value.(int64) + if !ok { + return fmt.Errorf("Expected int64, received %T", value) + } + w.Write(int32(8)) w.Write(v) + + return nil } func decodeInt2FromText(mr *MessageReader, size int32) interface{} { @@ -207,10 +219,16 @@ func decodeInt2FromBinary(mr *MessageReader, size int32) interface{} { return mr.ReadInt16() } -func encodeInt2(w *MessageWriter, value interface{}) { - v := value.(int16) +func encodeInt2(w *MessageWriter, value interface{}) error { + v, ok := value.(int16) + if !ok { + return fmt.Errorf("Expected int16, received %T", value) + } + w.Write(int32(2)) w.Write(v) + + return nil } func decodeInt4FromText(mr *MessageReader, size int32) interface{} { @@ -229,10 +247,16 @@ func decodeInt4FromBinary(mr *MessageReader, size int32) interface{} { return mr.ReadInt32() } -func encodeInt4(w *MessageWriter, value interface{}) { - v := value.(int32) +func encodeInt4(w *MessageWriter, value interface{}) error { + v, ok := value.(int32) + if !ok { + return fmt.Errorf("Expected int32, received %T", value) + } + w.Write(int32(4)) w.Write(v) + + return nil } func decodeFloat4FromText(mr *MessageReader, size int32) interface{} { @@ -254,10 +278,16 @@ func decodeFloat4FromBinary(mr *MessageReader, size int32) interface{} { return *(*float32)(p) } -func encodeFloat4(w *MessageWriter, value interface{}) { - v := value.(float32) +func encodeFloat4(w *MessageWriter, value interface{}) error { + v, ok := value.(float32) + if !ok { + return fmt.Errorf("Expected float32, received %T", value) + } + w.Write(int32(4)) w.Write(v) + + return nil } func decodeFloat8FromText(mr *MessageReader, size int32) interface{} { @@ -279,20 +309,32 @@ func decodeFloat8FromBinary(mr *MessageReader, size int32) interface{} { return *(*float64)(p) } -func encodeFloat8(w *MessageWriter, value interface{}) { - v := value.(float64) +func encodeFloat8(w *MessageWriter, value interface{}) error { + v, ok := value.(float64) + if !ok { + return fmt.Errorf("Expected float64, received %T", value) + } + w.Write(int32(8)) w.Write(v) + + return nil } func decodeTextFromText(mr *MessageReader, size int32) interface{} { return mr.ReadString(size) } -func encodeText(w *MessageWriter, value interface{}) { - s := value.(string) +func encodeText(w *MessageWriter, value interface{}) error { + s, ok := value.(string) + if !ok { + return fmt.Errorf("Expected string, received %T", value) + } + w.Write(int32(len(s))) w.WriteString(s) + + return nil } func decodeByteaFromText(mr *MessageReader, size int32) interface{} { @@ -304,10 +346,16 @@ func decodeByteaFromText(mr *MessageReader, size int32) interface{} { return b } -func encodeBytea(w *MessageWriter, value interface{}) { - b := value.([]byte) +func encodeBytea(w *MessageWriter, value interface{}) error { + b, ok := value.([]byte) + if !ok { + return fmt.Errorf("Expected []byte, received %T", value) + } + w.Write(int32(len(b))) w.Write(b) + + return nil } func decodeDateFromText(mr *MessageReader, size int32) interface{} { @@ -319,11 +367,17 @@ func decodeDateFromText(mr *MessageReader, size int32) interface{} { return t } -func encodeDate(w *MessageWriter, value interface{}) { - t := value.(time.Time) +func encodeDate(w *MessageWriter, value interface{}) error { + t, ok := value.(time.Time) + if !ok { + return fmt.Errorf("Expected time.Time, received %T", value) + } + s := t.Format("2006-01-02") w.Write(int32(len(s))) w.WriteString(s) + + return nil } func decodeTimestampTzFromText(mr *MessageReader, size int32) interface{} { @@ -349,11 +403,17 @@ func decodeTimestampTzFromBinary(mr *MessageReader, size int32) interface{} { } -func encodeTimestampTz(w *MessageWriter, value interface{}) { - t := value.(time.Time) +func encodeTimestampTz(w *MessageWriter, value interface{}) error { + t, ok := value.(time.Time) + if !ok { + return fmt.Errorf("Expected float32, received %T", value) + } + s := t.Format("2006-01-02 15:04:05.999999 -0700") w.Write(int32(len(s))) w.WriteString(s) + + return nil } func decodeInt2ArrayFromText(mr *MessageReader, size int32) interface{} { @@ -387,14 +447,21 @@ func int16SliceToArrayString(nums []int16) (string, error) { return w.buf.String(), w.Err } -func encodeInt2Array(w *MessageWriter, value interface{}) { - v := value.([]int16) +func encodeInt2Array(w *MessageWriter, value interface{}) error { + v, ok := value.([]int16) + if !ok { + return fmt.Errorf("Expected []int16, received %T", value) + } + s, err := int16SliceToArrayString(v) if err != nil { - w.Err = fmt.Errorf("Failed to encode []int16: %v", err) + return fmt.Errorf("Failed to encode []int16: %v", err) } + w.Write(int32(len(s))) w.WriteString(s) + + return nil } func decodeInt4ArrayFromText(mr *MessageReader, size int32) interface{} { @@ -428,14 +495,21 @@ func int32SliceToArrayString(nums []int32) (string, error) { return w.buf.String(), w.Err } -func encodeInt4Array(w *MessageWriter, value interface{}) { - v := value.([]int32) +func encodeInt4Array(w *MessageWriter, value interface{}) error { + v, ok := value.([]int32) + if !ok { + return fmt.Errorf("Expected []int32, received %T", value) + } + s, err := int32SliceToArrayString(v) if err != nil { - w.Err = fmt.Errorf("Failed to encode []int32: %v", err) + return fmt.Errorf("Failed to encode []int32: %v", err) } + w.Write(int32(len(s))) w.WriteString(s) + + return nil } func decodeInt8ArrayFromText(mr *MessageReader, size int32) interface{} { @@ -469,12 +543,19 @@ func int64SliceToArrayString(nums []int64) (string, error) { return w.buf.String(), w.Err } -func encodeInt8Array(w *MessageWriter, value interface{}) { - v := value.([]int64) +func encodeInt8Array(w *MessageWriter, value interface{}) error { + v, ok := value.([]int64) + if !ok { + return fmt.Errorf("Expected []int64, received %T", value) + } + s, err := int64SliceToArrayString(v) if err != nil { - w.Err = fmt.Errorf("Failed to encode []int64: %v", err) + return fmt.Errorf("Failed to encode []int64: %v", err) } + w.Write(int32(len(s))) w.WriteString(s) + + return nil } diff --git a/value_transcoder_test.go b/value_transcoder_test.go index 9f9d6be5..54444c84 100644 --- a/value_transcoder_test.go +++ b/value_transcoder_test.go @@ -5,6 +5,27 @@ import ( "time" ) +func TestTranscodeError(t *testing.T) { + conn := getSharedConnection(t) + + mustPrepare(t, conn, "testTranscode", "select $1::integer") + defer func() { + if err := conn.Deallocate("testTranscode"); err != nil { + t.Fatalf("Unable to deallocate prepared statement: %v", err) + } + }() + + _, err := conn.SelectValue("testTranscode", "wrong") + switch { + case err == nil: + t.Error("Expected transcode error to return error, but it didn't") + case err.Error() == "Expected int32, received string": + // Correct behavior + default: + t.Errorf("Expected transcode error, received %v", err) + } +} + func TestNilTranscode(t *testing.T) { conn := getSharedConnection(t)