From 5073a3b9e0443f07ee20808d6bddb41a3e5db76a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 1 Jul 2013 15:41:20 -0500 Subject: [PATCH] Dirty, but somewhat working prepared statements and extended protocol --- connection.go | 214 +++++++++++++++++++++++++++++++++++++++++--- connection_test.go | 44 +++++++++ data_row_reader.go | 13 ++- messages.go | 19 ++-- sanitize.go | 5 ++ value_transcoder.go | 21 ++--- 6 files changed, 285 insertions(+), 31 deletions(-) diff --git a/connection.go b/connection.go index 980fbedf..e08ab281 100644 --- a/connection.go +++ b/connection.go @@ -9,6 +9,8 @@ import ( "fmt" "io" "net" + "reflect" + "strconv" ) type ConnectionParameters struct { @@ -21,13 +23,20 @@ type ConnectionParameters struct { } type Connection struct { - conn net.Conn // the underlying TCP or unix domain socket connection - buf *bytes.Buffer // work buffer to avoid constant alloc and dealloc - pid int32 // backend pid - secretKey int32 // key to use to send a cancel query message to the server - runtimeParams map[string]string // parameters that have been reported by the server - parameters ConnectionParameters // parameters used when establishing this connection - txStatus byte + conn net.Conn // the underlying TCP or unix domain socket connection + buf *bytes.Buffer // work buffer to avoid constant alloc and dealloc + pid int32 // backend pid + secretKey int32 // key to use to send a cancel query message to the server + runtimeParams map[string]string // parameters that have been reported by the server + parameters ConnectionParameters // parameters used when establishing this connection + txStatus byte + preparedStatements map[string]*PreparedStatement +} + +type PreparedStatement struct { + Name string + FieldDescriptions []FieldDescription + ParameterOids []oid } type NotSingleRowError struct { @@ -71,6 +80,7 @@ func Connect(parameters ConnectionParameters) (c *Connection, err error) { c.buf = bytes.NewBuffer(make([]byte, sharedBufferSize)) c.runtimeParams = make(map[string]string) + c.preparedStatements = make(map[string]*PreparedStatement) msg := newStartupMessage() msg.options["user"] = c.parameters.User @@ -108,12 +118,19 @@ func (c *Connection) Close() (err error) { } func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error, arguments ...interface{}) (err error) { - if err = c.sendSimpleQuery(sql, arguments...); err != nil { + var fields []FieldDescription + + if ps, present := c.preparedStatements[sql]; present { + fields = ps.FieldDescriptions + err = c.sendPreparedQuery(ps, arguments...) + } else { + err = c.sendSimpleQuery(sql, arguments...) + } + if err != nil { return } var callbackError error - var fields []FieldDescription for { var t byte @@ -132,6 +149,7 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error callbackError = onDataRow(newDataRowReader(r, fields)) } case commandComplete: + case bindComplete: default: if err = c.processContextFreeMsg(t, r); err != nil { return @@ -207,6 +225,101 @@ func (c *Connection) SelectValues(sql string, arguments ...interface{}) (values return } +func (c *Connection) Prepare(name, sql string) (err error) { + // parse + buf := c.getBuf() + _, err = buf.WriteString(name) + if err != nil { + return + } + err = buf.WriteByte(0) + if err != nil { + return + } + _, err = buf.WriteString(sql) + if err != nil { + return + } + err = buf.WriteByte(0) + if err != nil { + return + } + err = binary.Write(buf, binary.BigEndian, int16(0)) + if err != nil { + return + } + + err = c.txMsg('P', buf) + if err != nil { + return + } + + // describe + buf = c.getBuf() + err = buf.WriteByte('S') + if err != nil { + return + } + _, err = buf.WriteString(name) + if err != nil { + return + } + err = buf.WriteByte(0) + if err != nil { + return + } + + err = c.txMsg('D', buf) + if err != nil { + return + } + + // sync + err = c.txMsg('S', c.getBuf()) + if err != nil { + return err + } + + ps := PreparedStatement{Name: name} + + for { + var t byte + var r *MessageReader + if t, r, err = c.rxMsg(); err == nil { + switch t { + case parseComplete: + case parameterDescription: + ps.ParameterOids = c.rxParameterDescription(r) + case rowDescription: + ps.FieldDescriptions = c.rxRowDescription(r) + case readyForQuery: + c.preparedStatements[name] = &ps + return + default: + if err = c.processContextFreeMsg(t, r); err != nil { + return + } + } + } else { + return + } + } +} + +func (c *Connection) Deallocate(name string) (err error) { + delete(c.preparedStatements, name) + _, err = c.Execute("deallocate " + c.QuoteIdentifier(name)) + return +} + +func (c *Connection) sendQuery(sql string, arguments ...interface{}) (err error) { + if ps, present := c.preparedStatements[sql]; present { + return c.sendPreparedQuery(ps, arguments...) + } else { + return c.sendSimpleQuery(sql, arguments...) + } +} + func (c *Connection) sendSimpleQuery(sql string, arguments ...interface{}) (err error) { if len(arguments) > 0 { sql = c.SanitizeSql(sql, arguments...) @@ -226,8 +339,78 @@ func (c *Connection) sendSimpleQuery(sql string, arguments ...interface{}) (err return c.txMsg('Q', buf) } +func (c *Connection) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}) (err error) { + if len(ps.ParameterOids) != len(arguments) { + return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOids), len(arguments)) + } + + // bind + buf := c.getBuf() + buf.WriteString("") + buf.WriteByte(0) + buf.WriteString(ps.Name) + buf.WriteByte(0) + binary.Write(buf, binary.BigEndian, int16(0)) + binary.Write(buf, binary.BigEndian, int16(len(arguments))) + for _, iArg := range arguments { + var s string + switch arg := iArg.(type) { + case string: + s = arg + case int16: + s = strconv.FormatInt(int64(arg), 10) + case int32: + s = strconv.FormatInt(int64(arg), 10) + case int64: + s = strconv.FormatInt(int64(arg), 10) + case float32: + s = strconv.FormatFloat(float64(arg), 'f', -1, 32) + case float64: + s = strconv.FormatFloat(arg, 'f', -1, 64) + case []byte: + s = `E'\\x` + hex.EncodeToString(arg) + `'` + default: + panic("Unable to encode type: " + reflect.TypeOf(arg).String()) + } + binary.Write(buf, binary.BigEndian, int32(len(s))) + buf.WriteString(s) + } + // for _, pd := range ps.ParameterOids { + // transcoder := valueTranscoders[pd] + // if transcoder == nil { + // return + // } + // } + binary.Write(buf, binary.BigEndian, int16(0)) + + err = c.txMsg('B', buf) + if err != nil { + return err + } + + // execute + buf = c.getBuf() + buf.WriteString("") + buf.WriteByte(0) + binary.Write(buf, binary.BigEndian, int32(0)) + + err = c.txMsg('E', buf) + if err != nil { + return err + } + + // sync + err = c.txMsg('S', c.getBuf()) + if err != nil { + return err + } + + return + +} + func (c *Connection) Execute(sql string, arguments ...interface{}) (commandTag string, err error) { - if err = c.sendSimpleQuery(sql, arguments...); err != nil { + if err = c.sendQuery(sql, arguments...); err != nil { return } @@ -235,11 +418,13 @@ func (c *Connection) Execute(sql string, arguments ...interface{}) (commandTag s var t byte var r *MessageReader if t, r, err = c.rxMsg(); err == nil { + // fmt.Printf("Execute received: %c\n", t) switch t { case readyForQuery: return case rowDescription: case dataRow: + case bindComplete: case commandComplete: commandTag = r.ReadString() default: @@ -378,6 +563,15 @@ func (c *Connection) rxRowDescription(r *MessageReader) (fields []FieldDescripti return } +func (c *Connection) rxParameterDescription(r *MessageReader) (parameters []oid) { + parameterCount := r.ReadInt16() + parameters = make([]oid, 0, parameterCount) + for i := int16(0); i < parameterCount; i++ { + parameters = append(parameters, r.ReadOid()) + } + return +} + func (c *Connection) rxDataRow(r *DataRowReader) (row map[string]interface{}) { fieldCount := len(r.fields) diff --git a/connection_test.go b/connection_test.go index eb0adc7f..957a6423 100644 --- a/connection_test.go +++ b/connection_test.go @@ -285,3 +285,47 @@ func TestSelectValues(t *testing.T) { t.Error("Multiple columns should have returned UnexpectedColumnCountError") } } + +func TestPrepare(t *testing.T) { + conn, err := Connect(ConnectionParameters{Socket: "/private/tmp/.s.PGSQL.5432", User: "pgx_none", Database: "pgx_test"}) + if err != nil { + t.Fatal("Unable to establish connection") + } + + testTranscode := func(sql string, value interface{}) { + if err = conn.Prepare("testTranscode", sql); err != nil { + t.Errorf("Unable to prepare statement: %v", err) + return + } + defer func() { + err := conn.Deallocate("testTranscode") + if err != nil { + t.Errorf("Deallocate failed: %v", err) + } + }() + + var result interface{} + result, err = conn.SelectValue("testTranscode", value) + if err != nil { + t.Errorf("%v while running %v", err, "testTranscode") + } else { + if result != value { + t.Errorf("Expected: %#v Received: %#v", value, result) + } + } + + } + + // Test parameter encoding and decoding for simple supported data types + testTranscode("select $1::varchar", "foo") + testTranscode("select $1::text", "foo") + testTranscode("select $1::int2", int16(1)) + testTranscode("select $1::int4", int32(1)) + testTranscode("select $1::int8", int64(1)) + testTranscode("select $1::float4", float32(1.23)) + testTranscode("select $1::float8", float64(1.23)) + + // case []byte: + // s = `E'\\x` + hex.EncodeToString(arg) + `'` + +} diff --git a/data_row_reader.go b/data_row_reader.go index 55072fe0..80d73292 100644 --- a/data_row_reader.go +++ b/data_row_reader.go @@ -20,13 +20,20 @@ func newDataRowReader(mr *MessageReader, fields []FieldDescription) (r *DataRowR } func (r *DataRowReader) ReadValue() interface{} { - dataType := r.fields[r.currentFieldIdx].DataType + fieldDescription := r.fields[r.currentFieldIdx] r.currentFieldIdx++ size := r.mr.ReadInt32() if size > -1 { - if vt, present := valueTranscoders[dataType]; present { - return vt.FromText(r.mr, size) + if vt, present := valueTranscoders[fieldDescription.DataType]; present { + switch fieldDescription.FormatCode { + case 0: + return vt.DecodeText(r.mr, size) + case 1: + return vt.DecodeBinary(r.mr, size) + default: + panic("Unknown format") + } } else { return r.mr.ReadByteString(size) } diff --git a/messages.go b/messages.go index 24ba294f..3afc40ab 100644 --- a/messages.go +++ b/messages.go @@ -9,14 +9,17 @@ const ( ) const ( - backendKeyData = 'K' - authenticationX = 'R' - readyForQuery = 'Z' - rowDescription = 'T' - dataRow = 'D' - commandComplete = 'C' - errorResponse = 'E' - noticeResponse = 'N' + backendKeyData = 'K' + authenticationX = 'R' + readyForQuery = 'Z' + rowDescription = 'T' + dataRow = 'D' + commandComplete = 'C' + errorResponse = 'E' + noticeResponse = 'N' + parseComplete = '1' + parameterDescription = 't' + bindComplete = '2' ) type startupMessage struct { diff --git a/sanitize.go b/sanitize.go index 0d2b92bc..25419e97 100644 --- a/sanitize.go +++ b/sanitize.go @@ -15,6 +15,11 @@ func (c *Connection) QuoteString(input string) (output string) { return } +func (c *Connection) QuoteIdentifier(input string) (output string) { + output = `"` + strings.Replace(input, `"`, `""`, -1) + `"` + return +} + func (c *Connection) SanitizeSql(sql string, args ...interface{}) (output string) { replacer := func(match string) (replacement string) { n, _ := strconv.ParseInt(match[1:], 10, 0) diff --git a/value_transcoder.go b/value_transcoder.go index a072b976..ce9b3c90 100644 --- a/value_transcoder.go +++ b/value_transcoder.go @@ -2,14 +2,15 @@ package pgx import ( "fmt" + "io" "strconv" ) type valueTranscoder struct { - FromText func(*MessageReader, int32) interface{} - // FromBinary func(*MessageReader, int32) interface{} - // ToText func(interface{}) string - // ToBinary func(interface{}) []byte + DecodeText func(*MessageReader, int32) interface{} + DecodeBinary func(*MessageReader, int32) interface{} + EncodeTo func(io.Writer, interface{}) + EncodeFormat int16 } var valueTranscoders map[oid]*valueTranscoder @@ -18,22 +19,22 @@ func init() { valueTranscoders = make(map[oid]*valueTranscoder) // bool - valueTranscoders[oid(16)] = &valueTranscoder{FromText: decodeBoolFromText} + valueTranscoders[oid(16)] = &valueTranscoder{DecodeText: decodeBoolFromText} // int8 - valueTranscoders[oid(20)] = &valueTranscoder{FromText: decodeInt8FromText} + valueTranscoders[oid(20)] = &valueTranscoder{DecodeText: decodeInt8FromText} // int2 - valueTranscoders[oid(21)] = &valueTranscoder{FromText: decodeInt2FromText} + valueTranscoders[oid(21)] = &valueTranscoder{DecodeText: decodeInt2FromText} // int4 - valueTranscoders[oid(23)] = &valueTranscoder{FromText: decodeInt4FromText} + valueTranscoders[oid(23)] = &valueTranscoder{DecodeText: decodeInt4FromText} // float4 - valueTranscoders[oid(700)] = &valueTranscoder{FromText: decodeFloat4FromText} + valueTranscoders[oid(700)] = &valueTranscoder{DecodeText: decodeFloat4FromText} // float8 - valueTranscoders[oid(701)] = &valueTranscoder{FromText: decodeFloat8FromText} + valueTranscoders[oid(701)] = &valueTranscoder{DecodeText: decodeFloat8FromText} } func decodeBoolFromText(mr *MessageReader, size int32) interface{} {