From fa4c70907cb8e6dbb76cc4a8d31ff6be89dcebcb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 17 Apr 2013 08:26:01 -0500 Subject: [PATCH] Rename pgx.conn to pgx.Connection --- conn.go | 76 ++++++++++++++++++++++++------------------------ conn_test.go | 38 ++++++++++++------------ sanitize.go | 4 +-- sanitize_test.go | 4 +-- 4 files changed, 61 insertions(+), 61 deletions(-) diff --git a/conn.go b/conn.go index 1e705571..04117398 100644 --- a/conn.go +++ b/conn.go @@ -11,7 +11,7 @@ import ( "strconv" ) -type conn struct { +type Connection struct { conn net.Conn // the underlying TCP or unix domain socket connection buf []byte // work buffer to avoid constant alloc and dealloc pid int32 // backend pid @@ -24,8 +24,8 @@ type conn struct { // options: // socket: path to unix domain socket // database: name of database -func Connect(options map[string]string) (c *conn, err error) { - c = new(conn) +func Connect(options map[string]string) (c *Connection, err error) { + c = new(Connection) c.options = make(map[string]string) for k, v := range options { @@ -82,7 +82,7 @@ func Connect(options map[string]string) (c *conn, err error) { panic("Unreachable") } -func (c *conn) Close() (err error) { +func (c *Connection) Close() (err error) { buf := c.getBuf(5) buf[0] = 'X' binary.BigEndian.PutUint32(buf[1:], 4) @@ -90,7 +90,7 @@ func (c *conn) Close() (err error) { return } -func (c *conn) query(sql string, onDataRow func(*messageReader, []fieldDescription) error) (err error) { +func (c *Connection) query(sql string, onDataRow func(*messageReader, []fieldDescription) error) (err error) { if err = c.sendSimpleQuery(sql); err != nil { return } @@ -128,7 +128,7 @@ func (c *conn) query(sql string, onDataRow func(*messageReader, []fieldDescripti panic("Unreachable") } -func (c *conn) Query(sql string) (rows []map[string]string, err error) { +func (c *Connection) Query(sql string) (rows []map[string]string, err error) { rows = make([]map[string]string, 0, 8) onDataRow := func(r *messageReader, fields []fieldDescription) error { rows = append(rows, c.rxDataRow(r, fields)) @@ -138,7 +138,7 @@ func (c *conn) Query(sql string) (rows []map[string]string, err error) { return } -func (c *conn) SelectString(sql string) (s string, err error) { +func (c *Connection) SelectString(sql string) (s string, err error) { onDataRow := func(r *messageReader, _ []fieldDescription) error { s = c.rxDataRowFirstValue(r) return nil @@ -147,7 +147,7 @@ func (c *conn) SelectString(sql string) (s string, err error) { return } -func (c *conn) selectInt(sql string, size int) (i int64, err error) { +func (c *Connection) selectInt(sql string, size int) (i int64, err error) { var s string s, err = c.SelectString(sql) if err != nil { @@ -158,25 +158,25 @@ func (c *conn) selectInt(sql string, size int) (i int64, err error) { return } -func (c *conn) SelectInt64(sql string) (i int64, err error) { +func (c *Connection) SelectInt64(sql string) (i int64, err error) { return c.selectInt(sql, 64) } -func (c *conn) SelectInt32(sql string) (i int32, err error) { +func (c *Connection) SelectInt32(sql string) (i int32, err error) { var i64 int64 i64, err = c.selectInt(sql, 32) i = int32(i64) return } -func (c *conn) SelectInt16(sql string) (i int16, err error) { +func (c *Connection) SelectInt16(sql string) (i int16, err error) { var i64 int64 i64, err = c.selectInt(sql, 16) i = int16(i64) return } -func (c *conn) selectFloat(sql string, size int) (f float64, err error) { +func (c *Connection) selectFloat(sql string, size int) (f float64, err error) { var s string s, err = c.SelectString(sql) if err != nil { @@ -187,18 +187,18 @@ func (c *conn) selectFloat(sql string, size int) (f float64, err error) { return } -func (c *conn) SelectFloat64(sql string) (f float64, err error) { +func (c *Connection) SelectFloat64(sql string) (f float64, err error) { return c.selectFloat(sql, 64) } -func (c *conn) SelectFloat32(sql string) (f float32, err error) { +func (c *Connection) SelectFloat32(sql string) (f float32, err error) { var f64 float64 f64, err = c.selectFloat(sql, 32) f = float32(f64) return } -func (c *conn) SelectAllString(sql string) (strings []string, err error) { +func (c *Connection) SelectAllString(sql string) (strings []string, err error) { strings = make([]string, 0, 8) onDataRow := func(r *messageReader, _ []fieldDescription) error { strings = append(strings, c.rxDataRowFirstValue(r)) @@ -208,7 +208,7 @@ func (c *conn) SelectAllString(sql string) (strings []string, err error) { return } -func (c *conn) SelectAllInt64(sql string) (ints []int64, err error) { +func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) { ints = make([]int64, 0, 8) onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { var i int64 @@ -220,7 +220,7 @@ func (c *conn) SelectAllInt64(sql string) (ints []int64, err error) { return } -func (c *conn) SelectAllInt32(sql string) (ints []int32, err error) { +func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) { ints = make([]int32, 0, 8) onDataRow := func(r *messageReader, fields []fieldDescription) (parseError error) { var i int64 @@ -232,7 +232,7 @@ func (c *conn) SelectAllInt32(sql string) (ints []int32, err error) { return } -func (c *conn) SelectAllInt16(sql string) (ints []int16, err error) { +func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) { ints = make([]int16, 0, 8) onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { var i int64 @@ -244,7 +244,7 @@ func (c *conn) SelectAllInt16(sql string) (ints []int16, err error) { return } -func (c *conn) SelectAllFloat64(sql string) (floats []float64, err error) { +func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error) { floats = make([]float64, 0, 8) onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { var f float64 @@ -256,7 +256,7 @@ func (c *conn) SelectAllFloat64(sql string) (floats []float64, err error) { return } -func (c *conn) SelectAllFloat32(sql string) (floats []float32, err error) { +func (c *Connection) SelectAllFloat32(sql string) (floats []float32, err error) { floats = make([]float32, 0, 8) onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { var f float64 @@ -268,7 +268,7 @@ func (c *conn) SelectAllFloat32(sql string) (floats []float32, err error) { return } -func (c *conn) sendSimpleQuery(sql string) (err error) { +func (c *Connection) sendSimpleQuery(sql string) (err error) { bufSize := 5 + len(sql) + 1 // message identifier (1), message size (4), null string terminator (1) buf := c.getBuf(bufSize) buf[0] = 'Q' @@ -280,7 +280,7 @@ func (c *conn) sendSimpleQuery(sql string) (err error) { return err } -func (c *conn) Execute(sql string) (commandTag string, err error) { +func (c *Connection) Execute(sql string) (commandTag string, err error) { if err = c.sendSimpleQuery(sql); err != nil { return } @@ -312,7 +312,7 @@ func (c *conn) Execute(sql string) (commandTag string, err error) { // Processes messages that are not exclusive to one context such as // authentication or query response. The response to these messages // is the same regardless of when they occur. -func (c *conn) processContextFreeMsg(t byte, r *messageReader) (err error) { +func (c *Connection) processContextFreeMsg(t byte, r *messageReader) (err error) { switch t { case 'S': c.rxParameterStatus(r) @@ -329,7 +329,7 @@ func (c *conn) processContextFreeMsg(t byte, r *messageReader) (err error) { } -func (c *conn) rxMsg() (t byte, r *messageReader, err error) { +func (c *Connection) rxMsg() (t byte, r *messageReader, err error) { var bodySize int32 t, bodySize, err = c.rxMsgHeader() if err != nil { @@ -345,7 +345,7 @@ func (c *conn) rxMsg() (t byte, r *messageReader, err error) { return } -func (c *conn) rxMsgHeader() (t byte, bodySize int32, err error) { +func (c *Connection) rxMsgHeader() (t byte, bodySize int32, err error) { buf := c.buf[:5] if _, err = io.ReadFull(c.conn, buf); err != nil { return 0, 0, err @@ -356,13 +356,13 @@ func (c *conn) rxMsgHeader() (t byte, bodySize int32, err error) { return t, bodySize, nil } -func (c *conn) rxMsgBody(bodySize int32) (buf []byte, err error) { +func (c *Connection) rxMsgBody(bodySize int32) (buf []byte, err error) { buf = c.getBuf(int(bodySize)) _, err = io.ReadFull(c.conn, buf) return } -func (c *conn) rxAuthenticationX(r *messageReader) (err error) { +func (c *Connection) rxAuthenticationX(r *messageReader) (err error) { code := r.readInt32() switch code { case 0: // AuthenticationOk @@ -385,13 +385,13 @@ func hexMD5(s string) string { return hex.EncodeToString(hash.Sum(nil)) } -func (c *conn) rxParameterStatus(r *messageReader) { +func (c *Connection) rxParameterStatus(r *messageReader) { key := r.readString() value := r.readString() c.runtimeParams[key] = value } -func (c *conn) rxErrorResponse(r *messageReader) (err PgError) { +func (c *Connection) rxErrorResponse(r *messageReader) (err PgError) { for { switch r.readByte() { case 'S': @@ -410,16 +410,16 @@ func (c *conn) rxErrorResponse(r *messageReader) (err PgError) { panic("Unreachable") } -func (c *conn) rxBackendKeyData(r *messageReader) { +func (c *Connection) rxBackendKeyData(r *messageReader) { c.pid = r.readInt32() c.secretKey = r.readInt32() } -func (c *conn) rxReadyForQuery(r *messageReader) { +func (c *Connection) rxReadyForQuery(r *messageReader) { c.txStatus = r.readByte() } -func (c *conn) rxRowDescription(r *messageReader) (fields []fieldDescription) { +func (c *Connection) rxRowDescription(r *messageReader) (fields []fieldDescription) { fieldCount := r.readInt16() fields = make([]fieldDescription, fieldCount) for i := int16(0); i < fieldCount; i++ { @@ -435,7 +435,7 @@ func (c *conn) rxRowDescription(r *messageReader) (fields []fieldDescription) { return } -func (c *conn) rxDataRow(r *messageReader, fields []fieldDescription) (row map[string]string) { +func (c *Connection) rxDataRow(r *messageReader, fields []fieldDescription) (row map[string]string) { fieldCount := r.readInt16() row = make(map[string]string, fieldCount) @@ -447,7 +447,7 @@ func (c *conn) rxDataRow(r *messageReader, fields []fieldDescription) (row map[s return } -func (c *conn) rxDataRowFirstValue(r *messageReader) (s string) { +func (c *Connection) rxDataRowFirstValue(r *messageReader) (s string) { r.readInt16() // ignore field count // TODO - handle nulls @@ -456,16 +456,16 @@ func (c *conn) rxDataRowFirstValue(r *messageReader) (s string) { return s } -func (c *conn) rxCommandComplete(r *messageReader) string { +func (c *Connection) rxCommandComplete(r *messageReader) string { return r.readString() } -func (c *conn) txStartupMessage(msg *startupMessage) (err error) { +func (c *Connection) txStartupMessage(msg *startupMessage) (err error) { _, err = c.conn.Write(msg.Bytes()) return } -func (c *conn) txPasswordMessage(password string) (err error) { +func (c *Connection) txPasswordMessage(password string) (err error) { bufSize := 5 + len(password) + 1 // message identifier (1), message size (4), password, null string terminator (1) buf := c.getBuf(bufSize) buf[0] = 'p' @@ -479,7 +479,7 @@ func (c *conn) txPasswordMessage(password string) (err error) { // Gets a []byte of n length. If possible it will reuse the connection buffer // otherwise it will allocate a new buffer -func (c *conn) getBuf(n int) (buf []byte) { +func (c *Connection) getBuf(n int) (buf []byte) { if n <= cap(c.buf) { buf = c.buf[:n] } else { diff --git a/conn_test.go b/conn_test.go index c638c39c..e6891511 100644 --- a/conn_test.go +++ b/conn_test.go @@ -5,18 +5,18 @@ import ( "testing" ) -var sharedConn *conn +var SharedConnection *Connection -func getSharedConn() (c *conn) { - if sharedConn == nil { +func getSharedConnection() (c *Connection) { + if SharedConnection == nil { var err error - sharedConn, err = Connect(map[string]string{"socket": "/private/tmp/.s.PGSQL.5432", "user": "pgx_none", "database": "pgx_test"}) + SharedConnection, err = Connect(map[string]string{"socket": "/private/tmp/.s.PGSQL.5432", "user": "pgx_none", "database": "pgx_test"}) if err != nil { panic("Unable to establish connection") } } - return sharedConn + return SharedConnection } func TestConnect(t *testing.T) { @@ -87,7 +87,7 @@ func TestConnectWithMD5Password(t *testing.T) { } func TestExecute(t *testing.T) { - conn := getSharedConn() + conn := getSharedConnection() results, err := conn.Execute("create temporary table foo(id serial primary key);") if err != nil { @@ -116,7 +116,7 @@ func TestExecute(t *testing.T) { } func TestQuery(t *testing.T) { - conn := getSharedConn() + conn := getSharedConnection() rows, err := conn.Query("select 'Jack' as name") if err != nil { @@ -133,7 +133,7 @@ func TestQuery(t *testing.T) { } func TestSelectString(t *testing.T) { - conn := getSharedConn() + conn := getSharedConnection() s, err := conn.SelectString("select 'foo'") if err != nil { @@ -146,7 +146,7 @@ func TestSelectString(t *testing.T) { } func TestSelectInt64(t *testing.T) { - conn := getSharedConn() + conn := getSharedConnection() i, err := conn.SelectInt64("select 1") if err != nil { @@ -169,7 +169,7 @@ func TestSelectInt64(t *testing.T) { } func TestSelectInt32(t *testing.T) { - conn := getSharedConn() + conn := getSharedConnection() i, err := conn.SelectInt32("select 1") if err != nil { @@ -192,7 +192,7 @@ func TestSelectInt32(t *testing.T) { } func TestSelectInt16(t *testing.T) { - conn := getSharedConn() + conn := getSharedConnection() i, err := conn.SelectInt16("select 1") if err != nil { @@ -215,7 +215,7 @@ func TestSelectInt16(t *testing.T) { } func TestSelectFloat64(t *testing.T) { - conn := getSharedConn() + conn := getSharedConnection() f, err := conn.SelectFloat64("select 1.23") if err != nil { @@ -228,7 +228,7 @@ func TestSelectFloat64(t *testing.T) { } func TestSelectFloat32(t *testing.T) { - conn := getSharedConn() + conn := getSharedConnection() f, err := conn.SelectFloat32("select 1.23") if err != nil { @@ -241,7 +241,7 @@ func TestSelectFloat32(t *testing.T) { } func TestSelectAllString(t *testing.T) { - conn := getSharedConn() + conn := getSharedConnection() s, err := conn.SelectAllString("select * from (values ('Matthew'), ('Mark'), ('Luke'), ('John')) t") if err != nil { @@ -254,7 +254,7 @@ func TestSelectAllString(t *testing.T) { } func TestSelectAllInt64(t *testing.T) { - conn := getSharedConn() + conn := getSharedConnection() i, err := conn.SelectAllInt64("select * from (values (1), (2)) t") if err != nil { @@ -277,7 +277,7 @@ func TestSelectAllInt64(t *testing.T) { } func TestSelectAllInt32(t *testing.T) { - conn := getSharedConn() + conn := getSharedConnection() i, err := conn.SelectAllInt32("select * from (values (1), (2)) t") if err != nil { @@ -300,7 +300,7 @@ func TestSelectAllInt32(t *testing.T) { } func TestSelectAllInt16(t *testing.T) { - conn := getSharedConn() + conn := getSharedConnection() i, err := conn.SelectAllInt16("select * from (values (1), (2)) t") if err != nil { @@ -323,7 +323,7 @@ func TestSelectAllInt16(t *testing.T) { } func TestSelectAllFloat64(t *testing.T) { - conn := getSharedConn() + conn := getSharedConnection() f, err := conn.SelectAllFloat64("select * from (values (1.23), (4.56)) t") if err != nil { @@ -336,7 +336,7 @@ func TestSelectAllFloat64(t *testing.T) { } func TestSelectAllFloat32(t *testing.T) { - conn := getSharedConn() + conn := getSharedConnection() f, err := conn.SelectAllFloat32("select * from (values (1.23), (4.56)) t") if err != nil { diff --git a/sanitize.go b/sanitize.go index 871b00cd..a90fdcea 100644 --- a/sanitize.go +++ b/sanitize.go @@ -9,12 +9,12 @@ import ( var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`) -func (c *conn) QuoteString(input string) (output string) { +func (c *Connection) QuoteString(input string) (output string) { output = "'" + strings.Replace(input, "'", "''", -1) + "'" return } -func (c *conn) SanitizeSql(sql string, args ...interface{}) (output string) { +func (c *Connection) SanitizeSql(sql string, args ...interface{}) (output string) { replacer := func(match string) (replacement string) { n, _ := strconv.ParseInt(match[1:], 10, 0) switch arg := args[n-1].(type) { diff --git a/sanitize_test.go b/sanitize_test.go index 06282eb7..56f552cc 100644 --- a/sanitize_test.go +++ b/sanitize_test.go @@ -5,7 +5,7 @@ import ( ) func TestQuoteString(t *testing.T) { - conn := getSharedConn() + conn := getSharedConnection() if conn.QuoteString("test") != "'test'" { t.Error("Failed to quote string") @@ -17,7 +17,7 @@ func TestQuoteString(t *testing.T) { } func TestSanitizeSql(t *testing.T) { - conn := getSharedConn() + conn := getSharedConnection() if conn.SanitizeSql("select $1, $2, $3", "Jack's", 42, 1.23) != "select 'Jack''s', 42, 1.23" { t.Error("Failed to sanitize sql")