2
0

SelectValue and SelectRow error if no rows found

fixes #14
This commit is contained in:
Jack Christensen
2013-06-29 13:10:24 -05:00
parent b53b014b54
commit aabf563a3b
2 changed files with 49 additions and 37 deletions
+28 -18
View File
@@ -29,6 +29,12 @@ type Connection struct {
txStatus byte txStatus byte
} }
type NoRowsFoundError struct {
msg string
}
func (e NoRowsFoundError) Error() string { return e.msg }
func Connect(parameters ConnectionParameters) (c *Connection, err error) { func Connect(parameters ConnectionParameters) (c *Connection, err error) {
c = new(Connection) c = new(Connection)
@@ -131,12 +137,8 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error
panic("Unreachable") panic("Unreachable")
} }
// Null values are not included in rows. However, because maps return the 0 value func (c *Connection) SelectRows(sql string) (rows []map[string]interface{}, err error) {
// for missing values this flattens nulls to empty string. If the caller needs to rows = make([]map[string]interface{}, 0, 8)
// distinguish between a real empty string and a null it can use the comma ok
// pattern when accessing the map
func (c *Connection) SelectRows(sql string) (rows []map[string]string, err error) {
rows = make([]map[string]string, 0, 8)
onDataRow := func(r *DataRowReader) error { onDataRow := func(r *DataRowReader) error {
rows = append(rows, c.rxDataRow(r)) rows = append(rows, c.rxDataRow(r))
return nil return nil
@@ -145,25 +147,37 @@ func (c *Connection) SelectRows(sql string) (rows []map[string]string, err error
return return
} }
// Null values are not included in row. However, because maps return the 0 value func (c *Connection) SelectRow(sql string) (row map[string]interface{}, err error) {
// for missing values this flattens nulls to empty string. If the caller needs to var numRowsFound int64
// distinguish between a real empty string and a null it can use the comma ok
// pattern when accessing the map
func (c *Connection) SelectRow(sql string) (row map[string]string, err error) {
onDataRow := func(r *DataRowReader) error { onDataRow := func(r *DataRowReader) error {
numRowsFound++
row = c.rxDataRow(r) row = c.rxDataRow(r)
return nil return nil
} }
err = c.SelectFunc(sql, onDataRow) err = c.SelectFunc(sql, onDataRow)
if err == nil {
if numRowsFound == 0 {
err = NoRowsFoundError{}
}
}
return return
} }
func (c *Connection) SelectValue(sql string) (v interface{}, err error) { func (c *Connection) SelectValue(sql string) (v interface{}, err error) {
var numRowsFound int64
onDataRow := func(r *DataRowReader) error { onDataRow := func(r *DataRowReader) error {
numRowsFound++
v = r.ReadValue() v = r.ReadValue()
return nil return nil
} }
err = c.SelectFunc(sql, onDataRow) err = c.SelectFunc(sql, onDataRow)
if err == nil {
if numRowsFound == 0 {
err = NoRowsFoundError{}
}
}
return return
} }
@@ -344,16 +358,12 @@ func (c *Connection) rxRowDescription(r *MessageReader) (fields []FieldDescripti
return return
} }
func (c *Connection) rxDataRow(r *DataRowReader) (row map[string]string) { func (c *Connection) rxDataRow(r *DataRowReader) (row map[string]interface{}) {
fieldCount := len(r.fields) fieldCount := len(r.fields)
mr := r.mr
row = make(map[string]string, fieldCount) row = make(map[string]interface{}, fieldCount)
for i := 0; i < fieldCount; i++ { for i := 0; i < fieldCount; i++ {
size := mr.ReadInt32() row[r.fields[i].Name] = r.ReadValue()
if size > -1 {
row[r.fields[i].Name] = mr.ReadByteString(size)
}
} }
return return
} }
+21 -19
View File
@@ -36,7 +36,7 @@ func TestConnect(t *testing.T) {
t.Error("Backend secret key not stored") t.Error("Backend secret key not stored")
} }
var rows []map[string]string var rows []map[string]interface{}
rows, err = conn.SelectRows("select current_database()") rows, err = conn.SelectRows("select current_database()")
if err != nil || rows[0]["current_database"] != "pgx_test" { if err != nil || rows[0]["current_database"] != "pgx_test" {
t.Error("Did not connect to specified database (pgx_text)") t.Error("Did not connect to specified database (pgx_text)")
@@ -164,12 +164,12 @@ func TestSelectRows(t *testing.T) {
t.Error("Received incorrect name") t.Error("Received incorrect name")
} }
value, presence := rows[0]["position"] if value, presence := rows[0]["position"]; presence {
if value != "" { if value != nil {
t.Error("Should have received empty string for null") t.Error("Should have received nil for null")
} }
if presence != false { } else {
t.Error("Null value shouldn't have been present in map") t.Error("Null value should have been present in map as nil")
} }
} }
@@ -185,20 +185,17 @@ func TestSelectRow(t *testing.T) {
t.Error("Received incorrect name") t.Error("Received incorrect name")
} }
value, presence := row["position"] if value, presence := row["position"]; presence {
if value != "" { if value != nil {
t.Error("Should have received empty string for null") t.Error("Should have received nil for null")
} }
if presence != false { } else {
t.Error("Null value shouldn't have been present in map") t.Error("Null value should have been present in map as nil")
} }
row, err = conn.SelectRow("select 'Jack' as name where 1=2") _, err = conn.SelectRow("select 'Jack' as name where 1=2")
if row != nil { if _, ok := err.(NoRowsFoundError); !ok {
t.Error("No matching row should have returned nil") t.Error("No matching row should have returned NoRowsFoundError")
}
if err != nil {
t.Fatal("Query failed")
} }
} }
@@ -224,6 +221,11 @@ func TestConnectionSelectValue(t *testing.T) {
test("select 1::int8", int64(1)) test("select 1::int8", int64(1))
test("select 1.23::float4", float32(1.23)) test("select 1.23::float4", float32(1.23))
test("select 1.23::float8", float64(1.23)) test("select 1.23::float8", float64(1.23))
_, err := conn.SelectValue("select 'Jack' as name where 1=2")
if _, ok := err.(NoRowsFoundError); !ok {
t.Error("No matching row should have returned NoRowsFoundError")
}
} }
func TestSelectValues(t *testing.T) { func TestSelectValues(t *testing.T) {