diff --git a/connection.go b/connection.go index 4dba463d..70dc17ae 100644 --- a/connection.go +++ b/connection.go @@ -109,8 +109,8 @@ func (c *Connection) Close() (err error) { return c.txMsg('X', c.getBuf()) } -func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error) (err error) { - if err = c.sendSimpleQuery(sql); err != nil { +func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error, arguments ...interface{}) (err error) { + if err = c.sendSimpleQuery(sql, arguments...); err != nil { return } @@ -147,18 +147,18 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error panic("Unreachable") } -func (c *Connection) SelectRows(sql string) (rows []map[string]interface{}, err error) { +func (c *Connection) SelectRows(sql string, arguments ...interface{}) (rows []map[string]interface{}, err error) { rows = make([]map[string]interface{}, 0, 8) onDataRow := func(r *DataRowReader) error { rows = append(rows, c.rxDataRow(r)) return nil } - err = c.SelectFunc(sql, onDataRow) + err = c.SelectFunc(sql, onDataRow, arguments...) return } // Returns a NotSingleRowError if exactly one row is not found -func (c *Connection) SelectRow(sql string) (row map[string]interface{}, err error) { +func (c *Connection) SelectRow(sql string, arguments ...interface{}) (row map[string]interface{}, err error) { var numRowsFound int64 onDataRow := func(r *DataRowReader) error { @@ -166,7 +166,7 @@ func (c *Connection) SelectRow(sql string) (row map[string]interface{}, err erro row = c.rxDataRow(r) return nil } - err = c.SelectFunc(sql, onDataRow) + err = c.SelectFunc(sql, onDataRow, arguments...) if err == nil && numRowsFound != 1 { err = NotSingleRowError{RowCount: numRowsFound} } @@ -175,7 +175,7 @@ func (c *Connection) SelectRow(sql string) (row map[string]interface{}, err erro // Returns a UnexpectedColumnCountError if exactly one column is not found // Returns a NotSingleRowError if exactly one row is not found -func (c *Connection) SelectValue(sql string) (v interface{}, err error) { +func (c *Connection) SelectValue(sql string, arguments ...interface{}) (v interface{}, err error) { var numRowsFound int64 onDataRow := func(r *DataRowReader) error { @@ -187,7 +187,7 @@ func (c *Connection) SelectValue(sql string) (v interface{}, err error) { v = r.ReadValue() return nil } - err = c.SelectFunc(sql, onDataRow) + err = c.SelectFunc(sql, onDataRow, arguments...) if err == nil { if numRowsFound != 1 { err = NotSingleRowError{RowCount: numRowsFound} @@ -197,7 +197,7 @@ func (c *Connection) SelectValue(sql string) (v interface{}, err error) { } // Returns a UnexpectedColumnCountError if exactly one column is not found -func (c *Connection) SelectValues(sql string) (values []interface{}, err error) { +func (c *Connection) SelectValues(sql string, arguments ...interface{}) (values []interface{}, err error) { values = make([]interface{}, 0, 8) onDataRow := func(r *DataRowReader) error { if len(r.fields) != 1 { @@ -207,11 +207,15 @@ func (c *Connection) SelectValues(sql string) (values []interface{}, err error) values = append(values, r.ReadValue()) return nil } - err = c.SelectFunc(sql, onDataRow) + err = c.SelectFunc(sql, onDataRow, arguments...) return } -func (c *Connection) sendSimpleQuery(sql string) (err error) { +func (c *Connection) sendSimpleQuery(sql string, arguments ...interface{}) (err error) { + if len(arguments) > 0 { + sql = c.SanitizeSql(sql, arguments...) + } + buf := c.getBuf() _, err = buf.WriteString(sql) @@ -226,8 +230,8 @@ func (c *Connection) sendSimpleQuery(sql string) (err error) { return c.txMsg('Q', buf) } -func (c *Connection) Execute(sql string) (commandTag string, err error) { - if err = c.sendSimpleQuery(sql); err != nil { +func (c *Connection) Execute(sql string, arguments ...interface{}) (commandTag string, err error) { + if err = c.sendSimpleQuery(sql, arguments...); err != nil { return } diff --git a/connection_test.go b/connection_test.go index efb3952f..eb0adc7f 100644 --- a/connection_test.go +++ b/connection_test.go @@ -100,7 +100,7 @@ func TestConnectWithMD5Password(t *testing.T) { func TestExecute(t *testing.T) { conn := getSharedConnection() - results, err := conn.Execute("create temporary table foo(id serial primary key);") + results, err := conn.Execute("create temporary table foo(id integer primary key);") if err != nil { t.Fatal("Execute failed: " + err.Error()) } @@ -108,6 +108,15 @@ func TestExecute(t *testing.T) { t.Error("Unexpected results from Execute") } + // Accept parameters + results, err = conn.Execute("insert into foo(id) values($1)", 1) + if err != nil { + t.Errorf("Execute failed: %v", err) + } + if results != "INSERT 0 1" { + t.Errorf("Unexpected results from Execute: %v", results) + } + results, err = conn.Execute("drop table foo;") if err != nil { t.Fatal("Execute failed: " + err.Error()) @@ -124,6 +133,7 @@ func TestExecute(t *testing.T) { if results != "DROP TABLE" { t.Error("Unexpected results from Execute") } + } func TestSelectFunc(t *testing.T) { @@ -136,7 +146,7 @@ func TestSelectFunc(t *testing.T) { return nil } - err := conn.SelectFunc("select generate_series(1,10)", onDataRow) + err := conn.SelectFunc("select generate_series(1,$1)", onDataRow, 10) if err != nil { t.Fatal("Select failed: " + err.Error()) } @@ -151,7 +161,7 @@ func TestSelectFunc(t *testing.T) { func TestSelectRows(t *testing.T) { conn := getSharedConnection() - rows, err := conn.SelectRows("select 'Jack' as name, null as position") + rows, err := conn.SelectRows("select $1 as name, null as position", "Jack") if err != nil { t.Fatal("Query failed") } @@ -176,7 +186,7 @@ func TestSelectRows(t *testing.T) { func TestSelectRow(t *testing.T) { conn := getSharedConnection() - row, err := conn.SelectRow("select 'Jack' as name, null as position") + row, err := conn.SelectRow("select $1 as name, null as position", "Jack") if err != nil { t.Fatal("Query failed") } @@ -207,8 +217,8 @@ func TestSelectRow(t *testing.T) { func TestConnectionSelectValue(t *testing.T) { conn := getSharedConnection() - test := func(sql string, expected interface{}) { - v, err := conn.SelectValue(sql) + test := func(sql string, expected interface{}, arguments ...interface{}) { + v, err := conn.SelectValue(sql, arguments...) if err != nil { t.Errorf("%v while running %v", err, sql) } else { @@ -218,6 +228,7 @@ func TestConnectionSelectValue(t *testing.T) { } } + test("select $1", "foo", "foo") test("select 'foo'", "foo") test("select true", true) test("select false", false) @@ -246,8 +257,8 @@ func TestConnectionSelectValue(t *testing.T) { func TestSelectValues(t *testing.T) { conn := getSharedConnection() - test := func(sql string, expected []interface{}) { - values, err := conn.SelectValues(sql) + test := func(sql string, expected []interface{}, arguments ...interface{}) { + values, err := conn.SelectValues(sql, arguments...) if err != nil { t.Errorf("%v while running %v", err, sql) return @@ -264,6 +275,7 @@ func TestSelectValues(t *testing.T) { } } + test("select * from (values ($1)) t", []interface{}{"Matthew"}, "Matthew") test("select * from (values ('Matthew'), ('Mark'), ('Luke'), ('John')) t", []interface{}{"Matthew", "Mark", "Luke", "John"}) test("select * from (values ('Matthew'), (null)) t", []interface{}{"Matthew", nil}) test("select * from (values (1::int4), (2::int4), (null), (3::int4)) t", []interface{}{int32(1), int32(2), nil, int32(3)})