+17
-13
@@ -109,8 +109,8 @@ func (c *Connection) Close() (err error) {
|
||||
return c.txMsg('X', c.getBuf())
|
||||
}
|
||||
|
||||
func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error) (err error) {
|
||||
if err = c.sendSimpleQuery(sql); err != nil {
|
||||
func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error, arguments ...interface{}) (err error) {
|
||||
if err = c.sendSimpleQuery(sql, arguments...); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -147,18 +147,18 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error
|
||||
panic("Unreachable")
|
||||
}
|
||||
|
||||
func (c *Connection) SelectRows(sql string) (rows []map[string]interface{}, err error) {
|
||||
func (c *Connection) SelectRows(sql string, arguments ...interface{}) (rows []map[string]interface{}, err error) {
|
||||
rows = make([]map[string]interface{}, 0, 8)
|
||||
onDataRow := func(r *DataRowReader) error {
|
||||
rows = append(rows, c.rxDataRow(r))
|
||||
return nil
|
||||
}
|
||||
err = c.SelectFunc(sql, onDataRow)
|
||||
err = c.SelectFunc(sql, onDataRow, arguments...)
|
||||
return
|
||||
}
|
||||
|
||||
// Returns a NotSingleRowError if exactly one row is not found
|
||||
func (c *Connection) SelectRow(sql string) (row map[string]interface{}, err error) {
|
||||
func (c *Connection) SelectRow(sql string, arguments ...interface{}) (row map[string]interface{}, err error) {
|
||||
var numRowsFound int64
|
||||
|
||||
onDataRow := func(r *DataRowReader) error {
|
||||
@@ -166,7 +166,7 @@ func (c *Connection) SelectRow(sql string) (row map[string]interface{}, err erro
|
||||
row = c.rxDataRow(r)
|
||||
return nil
|
||||
}
|
||||
err = c.SelectFunc(sql, onDataRow)
|
||||
err = c.SelectFunc(sql, onDataRow, arguments...)
|
||||
if err == nil && numRowsFound != 1 {
|
||||
err = NotSingleRowError{RowCount: numRowsFound}
|
||||
}
|
||||
@@ -175,7 +175,7 @@ func (c *Connection) SelectRow(sql string) (row map[string]interface{}, err erro
|
||||
|
||||
// Returns a UnexpectedColumnCountError if exactly one column is not found
|
||||
// Returns a NotSingleRowError if exactly one row is not found
|
||||
func (c *Connection) SelectValue(sql string) (v interface{}, err error) {
|
||||
func (c *Connection) SelectValue(sql string, arguments ...interface{}) (v interface{}, err error) {
|
||||
var numRowsFound int64
|
||||
|
||||
onDataRow := func(r *DataRowReader) error {
|
||||
@@ -187,7 +187,7 @@ func (c *Connection) SelectValue(sql string) (v interface{}, err error) {
|
||||
v = r.ReadValue()
|
||||
return nil
|
||||
}
|
||||
err = c.SelectFunc(sql, onDataRow)
|
||||
err = c.SelectFunc(sql, onDataRow, arguments...)
|
||||
if err == nil {
|
||||
if numRowsFound != 1 {
|
||||
err = NotSingleRowError{RowCount: numRowsFound}
|
||||
@@ -197,7 +197,7 @@ func (c *Connection) SelectValue(sql string) (v interface{}, err error) {
|
||||
}
|
||||
|
||||
// Returns a UnexpectedColumnCountError if exactly one column is not found
|
||||
func (c *Connection) SelectValues(sql string) (values []interface{}, err error) {
|
||||
func (c *Connection) SelectValues(sql string, arguments ...interface{}) (values []interface{}, err error) {
|
||||
values = make([]interface{}, 0, 8)
|
||||
onDataRow := func(r *DataRowReader) error {
|
||||
if len(r.fields) != 1 {
|
||||
@@ -207,11 +207,15 @@ func (c *Connection) SelectValues(sql string) (values []interface{}, err error)
|
||||
values = append(values, r.ReadValue())
|
||||
return nil
|
||||
}
|
||||
err = c.SelectFunc(sql, onDataRow)
|
||||
err = c.SelectFunc(sql, onDataRow, arguments...)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) sendSimpleQuery(sql string) (err error) {
|
||||
func (c *Connection) sendSimpleQuery(sql string, arguments ...interface{}) (err error) {
|
||||
if len(arguments) > 0 {
|
||||
sql = c.SanitizeSql(sql, arguments...)
|
||||
}
|
||||
|
||||
buf := c.getBuf()
|
||||
|
||||
_, err = buf.WriteString(sql)
|
||||
@@ -226,8 +230,8 @@ func (c *Connection) sendSimpleQuery(sql string) (err error) {
|
||||
return c.txMsg('Q', buf)
|
||||
}
|
||||
|
||||
func (c *Connection) Execute(sql string) (commandTag string, err error) {
|
||||
if err = c.sendSimpleQuery(sql); err != nil {
|
||||
func (c *Connection) Execute(sql string, arguments ...interface{}) (commandTag string, err error) {
|
||||
if err = c.sendSimpleQuery(sql, arguments...); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
+20
-8
@@ -100,7 +100,7 @@ func TestConnectWithMD5Password(t *testing.T) {
|
||||
func TestExecute(t *testing.T) {
|
||||
conn := getSharedConnection()
|
||||
|
||||
results, err := conn.Execute("create temporary table foo(id serial primary key);")
|
||||
results, err := conn.Execute("create temporary table foo(id integer primary key);")
|
||||
if err != nil {
|
||||
t.Fatal("Execute failed: " + err.Error())
|
||||
}
|
||||
@@ -108,6 +108,15 @@ func TestExecute(t *testing.T) {
|
||||
t.Error("Unexpected results from Execute")
|
||||
}
|
||||
|
||||
// Accept parameters
|
||||
results, err = conn.Execute("insert into foo(id) values($1)", 1)
|
||||
if err != nil {
|
||||
t.Errorf("Execute failed: %v", err)
|
||||
}
|
||||
if results != "INSERT 0 1" {
|
||||
t.Errorf("Unexpected results from Execute: %v", results)
|
||||
}
|
||||
|
||||
results, err = conn.Execute("drop table foo;")
|
||||
if err != nil {
|
||||
t.Fatal("Execute failed: " + err.Error())
|
||||
@@ -124,6 +133,7 @@ func TestExecute(t *testing.T) {
|
||||
if results != "DROP TABLE" {
|
||||
t.Error("Unexpected results from Execute")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestSelectFunc(t *testing.T) {
|
||||
@@ -136,7 +146,7 @@ func TestSelectFunc(t *testing.T) {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := conn.SelectFunc("select generate_series(1,10)", onDataRow)
|
||||
err := conn.SelectFunc("select generate_series(1,$1)", onDataRow, 10)
|
||||
if err != nil {
|
||||
t.Fatal("Select failed: " + err.Error())
|
||||
}
|
||||
@@ -151,7 +161,7 @@ func TestSelectFunc(t *testing.T) {
|
||||
func TestSelectRows(t *testing.T) {
|
||||
conn := getSharedConnection()
|
||||
|
||||
rows, err := conn.SelectRows("select 'Jack' as name, null as position")
|
||||
rows, err := conn.SelectRows("select $1 as name, null as position", "Jack")
|
||||
if err != nil {
|
||||
t.Fatal("Query failed")
|
||||
}
|
||||
@@ -176,7 +186,7 @@ func TestSelectRows(t *testing.T) {
|
||||
func TestSelectRow(t *testing.T) {
|
||||
conn := getSharedConnection()
|
||||
|
||||
row, err := conn.SelectRow("select 'Jack' as name, null as position")
|
||||
row, err := conn.SelectRow("select $1 as name, null as position", "Jack")
|
||||
if err != nil {
|
||||
t.Fatal("Query failed")
|
||||
}
|
||||
@@ -207,8 +217,8 @@ func TestSelectRow(t *testing.T) {
|
||||
func TestConnectionSelectValue(t *testing.T) {
|
||||
conn := getSharedConnection()
|
||||
|
||||
test := func(sql string, expected interface{}) {
|
||||
v, err := conn.SelectValue(sql)
|
||||
test := func(sql string, expected interface{}, arguments ...interface{}) {
|
||||
v, err := conn.SelectValue(sql, arguments...)
|
||||
if err != nil {
|
||||
t.Errorf("%v while running %v", err, sql)
|
||||
} else {
|
||||
@@ -218,6 +228,7 @@ func TestConnectionSelectValue(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
test("select $1", "foo", "foo")
|
||||
test("select 'foo'", "foo")
|
||||
test("select true", true)
|
||||
test("select false", false)
|
||||
@@ -246,8 +257,8 @@ func TestConnectionSelectValue(t *testing.T) {
|
||||
func TestSelectValues(t *testing.T) {
|
||||
conn := getSharedConnection()
|
||||
|
||||
test := func(sql string, expected []interface{}) {
|
||||
values, err := conn.SelectValues(sql)
|
||||
test := func(sql string, expected []interface{}, arguments ...interface{}) {
|
||||
values, err := conn.SelectValues(sql, arguments...)
|
||||
if err != nil {
|
||||
t.Errorf("%v while running %v", err, sql)
|
||||
return
|
||||
@@ -264,6 +275,7 @@ func TestSelectValues(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
test("select * from (values ($1)) t", []interface{}{"Matthew"}, "Matthew")
|
||||
test("select * from (values ('Matthew'), ('Mark'), ('Luke'), ('John')) t", []interface{}{"Matthew", "Mark", "Luke", "John"})
|
||||
test("select * from (values ('Matthew'), (null)) t", []interface{}{"Matthew", nil})
|
||||
test("select * from (values (1::int4), (2::int4), (null), (3::int4)) t", []interface{}{int32(1), int32(2), nil, int32(3)})
|
||||
|
||||
Reference in New Issue
Block a user