2
0

Prepare returns *PreparedStatement

This commit is contained in:
Jack Christensen
2014-06-20 10:47:31 -05:00
parent cc445627b0
commit 4efa61bf5b
3 changed files with 17 additions and 17 deletions
+12 -12
View File
@@ -59,14 +59,14 @@ type Conn struct {
RuntimeParams map[string]string // parameters that have been reported by the server RuntimeParams map[string]string // parameters that have been reported by the server
config ConnConfig // config used when establishing this connection config ConnConfig // config used when establishing this connection
TxStatus byte TxStatus byte
preparedStatements map[string]*preparedStatement preparedStatements map[string]*PreparedStatement
notifications []*Notification notifications []*Notification
alive bool alive bool
causeOfDeath error causeOfDeath error
logger log.Logger logger log.Logger
} }
type preparedStatement struct { type PreparedStatement struct {
Name string Name string
FieldDescriptions []FieldDescription FieldDescriptions []FieldDescription
ParameterOids []Oid ParameterOids []Oid
@@ -185,7 +185,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
c.bufSize = c.config.MsgBufSize c.bufSize = c.config.MsgBufSize
c.buf = bytes.NewBuffer(make([]byte, 0, c.bufSize)) c.buf = bytes.NewBuffer(make([]byte, 0, c.bufSize))
c.RuntimeParams = make(map[string]string) c.RuntimeParams = make(map[string]string)
c.preparedStatements = make(map[string]*preparedStatement) c.preparedStatements = make(map[string]*PreparedStatement)
c.alive = true c.alive = true
if config.TLSConfig != nil { if config.TLSConfig != nil {
@@ -579,7 +579,7 @@ func (c *Conn) SelectValues(sql string, arguments ...interface{}) ([]interface{}
// Prepare creates a prepared statement with name and sql. sql can contain placeholders // Prepare creates a prepared statement with name and sql. sql can contain placeholders
// for bound parameters. These placeholders are referenced positional as $1, $2, etc. // for bound parameters. These placeholders are referenced positional as $1, $2, etc.
func (c *Conn) Prepare(name, sql string) (err error) { func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
defer func() { defer func() {
if err != nil { if err != nil {
c.logger.Error(fmt.Sprintf("Prepare `%s` as `%s` failed: %v", name, sql, err)) c.logger.Error(fmt.Sprintf("Prepare `%s` as `%s` failed: %v", name, sql, err))
@@ -595,7 +595,7 @@ func (c *Conn) Prepare(name, sql string) (err error) {
binary.Write(buf, binary.BigEndian, int16(0)) binary.Write(buf, binary.BigEndian, int16(0))
err = c.txMsg('P', buf, false) err = c.txMsg('P', buf, false)
if err != nil { if err != nil {
return err return nil, err
} }
// describe // describe
@@ -605,16 +605,16 @@ func (c *Conn) Prepare(name, sql string) (err error) {
buf.WriteByte(0) buf.WriteByte(0)
err = c.txMsg('D', buf, false) err = c.txMsg('D', buf, false)
if err != nil { if err != nil {
return return nil, err
} }
// sync // sync
err = c.txMsg('S', c.getBuf(), true) err = c.txMsg('S', c.getBuf(), true)
if err != nil { if err != nil {
return err return nil, err
} }
ps := preparedStatement{Name: name} ps = &PreparedStatement{Name: name}
var softErr error var softErr error
@@ -623,7 +623,7 @@ func (c *Conn) Prepare(name, sql string) (err error) {
var r *MessageReader var r *MessageReader
t, r, err := c.rxMsg() t, r, err := c.rxMsg()
if err != nil { if err != nil {
return err return nil, err
} }
switch t { switch t {
@@ -641,8 +641,8 @@ func (c *Conn) Prepare(name, sql string) (err error) {
case noData: case noData:
case readyForQuery: case readyForQuery:
c.rxReadyForQuery(r) c.rxReadyForQuery(r)
c.preparedStatements[name] = &ps c.preparedStatements[name] = ps
return softErr return ps, softErr
default: default:
if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil {
softErr = e softErr = e
@@ -761,7 +761,7 @@ func (c *Conn) sendSimpleQuery(sql string, arguments ...interface{}) (err error)
return c.txMsg('Q', buf, true) return c.txMsg('Q', buf, true)
} }
func (c *Conn) sendPreparedQuery(ps *preparedStatement, arguments ...interface{}) (err error) { func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}) (err error) {
if len(ps.ParameterOids) != len(arguments) { if len(ps.ParameterOids) != len(arguments) {
return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOids), len(arguments)) return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOids), len(arguments))
} }
+4 -4
View File
@@ -520,7 +520,7 @@ func TestPrepare(t *testing.T) {
defer conn.Close() defer conn.Close()
testTranscode := func(sql string, value interface{}) { testTranscode := func(sql string, value interface{}) {
if err = conn.Prepare("testTranscode", sql); err != nil { if _, err = conn.Prepare("testTranscode", sql); err != nil {
t.Errorf("Unable to prepare statement: %v", err) t.Errorf("Unable to prepare statement: %v", err)
return return
} }
@@ -555,7 +555,7 @@ func TestPrepare(t *testing.T) {
// Ensure that unknown types are just treated as strings // Ensure that unknown types are just treated as strings
testTranscode("select $1::point", "(0,0)") testTranscode("select $1::point", "(0,0)")
if err = conn.Prepare("testByteSliceTranscode", "select $1::bytea"); err != nil { if _, err = conn.Prepare("testByteSliceTranscode", "select $1::bytea"); err != nil {
t.Errorf("Unable to prepare statement: %v", err) t.Errorf("Unable to prepare statement: %v", err)
return return
} }
@@ -588,7 +588,7 @@ func TestPrepare(t *testing.T) {
} }
mustExecute(t, conn, "create temporary table foo(id serial)") mustExecute(t, conn, "create temporary table foo(id serial)")
if err = conn.Prepare("deleteFoo", "delete from foo"); err != nil { if _, err = conn.Prepare("deleteFoo", "delete from foo"); err != nil {
t.Fatalf("Unable to prepare delete: %v", err) t.Fatalf("Unable to prepare delete: %v", err)
} }
} }
@@ -600,7 +600,7 @@ func TestPrepareFailure(t *testing.T) {
} }
defer conn.Close() defer conn.Close()
if err = conn.Prepare("badSQL", "select foo"); err == nil { if _, err = conn.Prepare("badSQL", "select foo"); err == nil {
t.Fatal("Prepare should have failed with syntax error") t.Fatal("Prepare should have failed with syntax error")
} }
+1 -1
View File
@@ -21,7 +21,7 @@ func getSharedConnection(t testing.TB) (c *pgx.Conn) {
} }
func mustPrepare(t testing.TB, conn *pgx.Conn, name, sql string) { func mustPrepare(t testing.TB, conn *pgx.Conn, name, sql string) {
if err := conn.Prepare(name, sql); err != nil { if _, err := conn.Prepare(name, sql); err != nil {
t.Fatalf("Could not prepare %v: %v", name, err) t.Fatalf("Could not prepare %v: %v", name, err)
} }
} }