diff --git a/conn.go b/conn.go index 08d42d46..3c350b5f 100644 --- a/conn.go +++ b/conn.go @@ -90,7 +90,7 @@ func (c *Connection) Close() (err error) { return } -func (c *Connection) query(sql string, onDataRow func(*messageReader, []fieldDescription) error) (err error) { +func (c *Connection) Select(sql string, onDataRow func(*messageReader, []fieldDescription) error) (err error) { if err = c.sendSimpleQuery(sql); err != nil { return } @@ -134,7 +134,7 @@ func (c *Connection) Query(sql string) (rows []map[string]string, err error) { rows = append(rows, c.rxDataRow(r, fields)) return nil } - err = c.query(sql, onDataRow) + err = c.Select(sql, onDataRow) return } @@ -143,7 +143,7 @@ func (c *Connection) SelectString(sql string) (s string, err error) { s = c.rxDataRowFirstValue(r) return nil } - err = c.query(sql, onDataRow) + err = c.Select(sql, onDataRow) return } @@ -204,7 +204,7 @@ func (c *Connection) SelectAllString(sql string) (strings []string, err error) { strings = append(strings, c.rxDataRowFirstValue(r)) return nil } - err = c.query(sql, onDataRow) + err = c.Select(sql, onDataRow) return } @@ -216,7 +216,7 @@ func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) { ints = append(ints, i) return } - err = c.query(sql, onDataRow) + err = c.Select(sql, onDataRow) return } @@ -228,7 +228,7 @@ func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) { ints = append(ints, int32(i)) return } - err = c.query(sql, onDataRow) + err = c.Select(sql, onDataRow) return } @@ -240,7 +240,7 @@ func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) { ints = append(ints, int16(i)) return } - err = c.query(sql, onDataRow) + err = c.Select(sql, onDataRow) return } @@ -252,7 +252,7 @@ func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error) floats = append(floats, f) return } - err = c.query(sql, onDataRow) + err = c.Select(sql, onDataRow) return } @@ -264,7 +264,7 @@ func (c *Connection) SelectAllFloat32(sql string) (floats []float32, err error) floats = append(floats, float32(f)) return } - err = c.query(sql, onDataRow) + err = c.Select(sql, onDataRow) return } diff --git a/conn_test.go b/conn_test.go index e6891511..870bbd30 100644 --- a/conn_test.go +++ b/conn_test.go @@ -115,6 +115,24 @@ func TestExecute(t *testing.T) { } } +func TestSelect(t *testing.T) { + conn := getSharedConnection() + + rowCount := 0 + onDataRow := func(r *messageReader, fields []fieldDescription) error { + rowCount++ + return nil + } + + err := conn.Select("select generate_series(1,10)", onDataRow) + if err != nil { + t.Fatal("Select failed: " + err.Error()) + } + if rowCount != 10 { + t.Fatal("Select called onDataRow wrong number of times") + } +} + func TestQuery(t *testing.T) { conn := getSharedConnection()