+28
-18
@@ -29,6 +29,12 @@ type Connection struct {
|
|||||||
txStatus byte
|
txStatus byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type NoRowsFoundError struct {
|
||||||
|
msg string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e NoRowsFoundError) Error() string { return e.msg }
|
||||||
|
|
||||||
func Connect(parameters ConnectionParameters) (c *Connection, err error) {
|
func Connect(parameters ConnectionParameters) (c *Connection, err error) {
|
||||||
c = new(Connection)
|
c = new(Connection)
|
||||||
|
|
||||||
@@ -131,12 +137,8 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error
|
|||||||
panic("Unreachable")
|
panic("Unreachable")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Null values are not included in rows. However, because maps return the 0 value
|
func (c *Connection) SelectRows(sql string) (rows []map[string]interface{}, err error) {
|
||||||
// for missing values this flattens nulls to empty string. If the caller needs to
|
rows = make([]map[string]interface{}, 0, 8)
|
||||||
// distinguish between a real empty string and a null it can use the comma ok
|
|
||||||
// pattern when accessing the map
|
|
||||||
func (c *Connection) SelectRows(sql string) (rows []map[string]string, err error) {
|
|
||||||
rows = make([]map[string]string, 0, 8)
|
|
||||||
onDataRow := func(r *DataRowReader) error {
|
onDataRow := func(r *DataRowReader) error {
|
||||||
rows = append(rows, c.rxDataRow(r))
|
rows = append(rows, c.rxDataRow(r))
|
||||||
return nil
|
return nil
|
||||||
@@ -145,25 +147,37 @@ func (c *Connection) SelectRows(sql string) (rows []map[string]string, err error
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Null values are not included in row. However, because maps return the 0 value
|
func (c *Connection) SelectRow(sql string) (row map[string]interface{}, err error) {
|
||||||
// for missing values this flattens nulls to empty string. If the caller needs to
|
var numRowsFound int64
|
||||||
// distinguish between a real empty string and a null it can use the comma ok
|
|
||||||
// pattern when accessing the map
|
|
||||||
func (c *Connection) SelectRow(sql string) (row map[string]string, err error) {
|
|
||||||
onDataRow := func(r *DataRowReader) error {
|
onDataRow := func(r *DataRowReader) error {
|
||||||
|
numRowsFound++
|
||||||
row = c.rxDataRow(r)
|
row = c.rxDataRow(r)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
err = c.SelectFunc(sql, onDataRow)
|
err = c.SelectFunc(sql, onDataRow)
|
||||||
|
if err == nil {
|
||||||
|
if numRowsFound == 0 {
|
||||||
|
err = NoRowsFoundError{}
|
||||||
|
}
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connection) SelectValue(sql string) (v interface{}, err error) {
|
func (c *Connection) SelectValue(sql string) (v interface{}, err error) {
|
||||||
|
var numRowsFound int64
|
||||||
|
|
||||||
onDataRow := func(r *DataRowReader) error {
|
onDataRow := func(r *DataRowReader) error {
|
||||||
|
numRowsFound++
|
||||||
v = r.ReadValue()
|
v = r.ReadValue()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
err = c.SelectFunc(sql, onDataRow)
|
err = c.SelectFunc(sql, onDataRow)
|
||||||
|
if err == nil {
|
||||||
|
if numRowsFound == 0 {
|
||||||
|
err = NoRowsFoundError{}
|
||||||
|
}
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -344,16 +358,12 @@ func (c *Connection) rxRowDescription(r *MessageReader) (fields []FieldDescripti
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connection) rxDataRow(r *DataRowReader) (row map[string]string) {
|
func (c *Connection) rxDataRow(r *DataRowReader) (row map[string]interface{}) {
|
||||||
fieldCount := len(r.fields)
|
fieldCount := len(r.fields)
|
||||||
mr := r.mr
|
|
||||||
|
|
||||||
row = make(map[string]string, fieldCount)
|
row = make(map[string]interface{}, fieldCount)
|
||||||
for i := 0; i < fieldCount; i++ {
|
for i := 0; i < fieldCount; i++ {
|
||||||
size := mr.ReadInt32()
|
row[r.fields[i].Name] = r.ReadValue()
|
||||||
if size > -1 {
|
|
||||||
row[r.fields[i].Name] = mr.ReadByteString(size)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
+21
-19
@@ -36,7 +36,7 @@ func TestConnect(t *testing.T) {
|
|||||||
t.Error("Backend secret key not stored")
|
t.Error("Backend secret key not stored")
|
||||||
}
|
}
|
||||||
|
|
||||||
var rows []map[string]string
|
var rows []map[string]interface{}
|
||||||
rows, err = conn.SelectRows("select current_database()")
|
rows, err = conn.SelectRows("select current_database()")
|
||||||
if err != nil || rows[0]["current_database"] != "pgx_test" {
|
if err != nil || rows[0]["current_database"] != "pgx_test" {
|
||||||
t.Error("Did not connect to specified database (pgx_text)")
|
t.Error("Did not connect to specified database (pgx_text)")
|
||||||
@@ -164,12 +164,12 @@ func TestSelectRows(t *testing.T) {
|
|||||||
t.Error("Received incorrect name")
|
t.Error("Received incorrect name")
|
||||||
}
|
}
|
||||||
|
|
||||||
value, presence := rows[0]["position"]
|
if value, presence := rows[0]["position"]; presence {
|
||||||
if value != "" {
|
if value != nil {
|
||||||
t.Error("Should have received empty string for null")
|
t.Error("Should have received nil for null")
|
||||||
}
|
}
|
||||||
if presence != false {
|
} else {
|
||||||
t.Error("Null value shouldn't have been present in map")
|
t.Error("Null value should have been present in map as nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -185,20 +185,17 @@ func TestSelectRow(t *testing.T) {
|
|||||||
t.Error("Received incorrect name")
|
t.Error("Received incorrect name")
|
||||||
}
|
}
|
||||||
|
|
||||||
value, presence := row["position"]
|
if value, presence := row["position"]; presence {
|
||||||
if value != "" {
|
if value != nil {
|
||||||
t.Error("Should have received empty string for null")
|
t.Error("Should have received nil for null")
|
||||||
}
|
}
|
||||||
if presence != false {
|
} else {
|
||||||
t.Error("Null value shouldn't have been present in map")
|
t.Error("Null value should have been present in map as nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
row, err = conn.SelectRow("select 'Jack' as name where 1=2")
|
_, err = conn.SelectRow("select 'Jack' as name where 1=2")
|
||||||
if row != nil {
|
if _, ok := err.(NoRowsFoundError); !ok {
|
||||||
t.Error("No matching row should have returned nil")
|
t.Error("No matching row should have returned NoRowsFoundError")
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("Query failed")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -224,6 +221,11 @@ func TestConnectionSelectValue(t *testing.T) {
|
|||||||
test("select 1::int8", int64(1))
|
test("select 1::int8", int64(1))
|
||||||
test("select 1.23::float4", float32(1.23))
|
test("select 1.23::float4", float32(1.23))
|
||||||
test("select 1.23::float8", float64(1.23))
|
test("select 1.23::float8", float64(1.23))
|
||||||
|
|
||||||
|
_, err := conn.SelectValue("select 'Jack' as name where 1=2")
|
||||||
|
if _, ok := err.(NoRowsFoundError); !ok {
|
||||||
|
t.Error("No matching row should have returned NoRowsFoundError")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSelectValues(t *testing.T) {
|
func TestSelectValues(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user