diff --git a/connection_pool_test.go b/connection_pool_test.go index da63651b..bc0a5175 100644 --- a/connection_pool_test.go +++ b/connection_pool_test.go @@ -88,11 +88,12 @@ func TestPoolAcquireAndReleaseCycle(t *testing.T) { allConnections = acquireAll() for _, c := range allConnections { - n, err := c.SelectInt32("select counter from t") + v, err := c.SelectValue("select counter from t") if err != nil { t.Fatal("Unable to read back execution counter: " + err.Error()) } + n := v.(int32) if n == 0 { t.Error("A connection was never used") } diff --git a/connection_select_value.go b/connection_select_value.go index ebe93099..da832059 100644 --- a/connection_select_value.go +++ b/connection_select_value.go @@ -1,70 +1,69 @@ package pgx import ( - "errors" + "fmt" "strconv" ) -func (c *Connection) SelectString(sql string) (s string, err error) { +func (c *Connection) SelectValue(sql string) (v interface{}, err error) { onDataRow := func(r *DataRowReader) error { - var null bool - s, null = c.rxDataRowFirstValue(r.mr) - if null { - return errors.New("Unexpected NULL") + size := r.mr.ReadInt32() + if size > -1 { + switch r.fields[0].DataType { + case oid(16): // bool + s := r.mr.ReadByteString(size) + switch s { + case "t": + v = true + case "f": + v = false + default: + fmt.Errorf("Received invalid bool: %v", s) + } + case oid(20): // int8 + s := r.mr.ReadByteString(size) + v, err = strconv.ParseInt(s, 10, 64) + if err != nil { + fmt.Errorf("Received invalid int8: %v", s) + } + case oid(21): // int2 + s := r.mr.ReadByteString(size) + var n int64 + n, err = strconv.ParseInt(s, 10, 16) + if err != nil { + fmt.Errorf("Received invalid int2: %v", s) + } + v = int16(n) + case oid(23): // int4 + s := r.mr.ReadByteString(size) + var n int64 + n, err = strconv.ParseInt(s, 10, 32) + if err != nil { + fmt.Errorf("Received invalid int4: %v", s) + } + v = int32(n) + case oid(700): // float4 + s := r.mr.ReadByteString(size) + var n float64 + n, err = strconv.ParseFloat(s, 32) + if err != nil { + fmt.Errorf("Received invalid float4: %v", s) + } + v = float32(n) + case oid(701): //float8 + s := r.mr.ReadByteString(size) + v, err = strconv.ParseFloat(s, 64) + if err != nil { + fmt.Errorf("Received invalid float8: %v", s) + } + default: + v = r.mr.ReadByteString(size) + } + } else { + v = nil } return nil } err = c.SelectFunc(sql, onDataRow) return } - -func (c *Connection) selectInt(sql string, size int) (i int64, err error) { - var s string - s, err = c.SelectString(sql) - if err != nil { - return - } - - i, err = strconv.ParseInt(s, 10, size) - return -} - -func (c *Connection) SelectInt64(sql string) (i int64, err error) { - return c.selectInt(sql, 64) -} - -func (c *Connection) SelectInt32(sql string) (i int32, err error) { - var i64 int64 - i64, err = c.selectInt(sql, 32) - i = int32(i64) - return -} - -func (c *Connection) SelectInt16(sql string) (i int16, err error) { - var i64 int64 - i64, err = c.selectInt(sql, 16) - i = int16(i64) - return -} - -func (c *Connection) selectFloat(sql string, size int) (f float64, err error) { - var s string - s, err = c.SelectString(sql) - if err != nil { - return - } - - f, err = strconv.ParseFloat(s, size) - return -} - -func (c *Connection) SelectFloat64(sql string) (f float64, err error) { - return c.selectFloat(sql, 64) -} - -func (c *Connection) SelectFloat32(sql string) (f float32, err error) { - var f64 float64 - f64, err = c.selectFloat(sql, 32) - f = float32(f64) - return -} diff --git a/connection_select_value_test.go b/connection_select_value_test.go index 5dfc3ee2..5d48351c 100644 --- a/connection_select_value_test.go +++ b/connection_select_value_test.go @@ -1,146 +1,90 @@ package pgx import ( - "strings" "testing" ) -func TestSelectString(t *testing.T) { +func TestSelectValue(t *testing.T) { conn := getSharedConnection() + var v interface{} + var err error - s, err := conn.SelectString("select 'foo'") + v, err = conn.SelectValue("select null") if err != nil { - t.Error("Unable to select string: " + err.Error()) - } else if s != "foo" { - t.Error("Received incorrect string") + t.Errorf("Unable to select null: %v", err) + } else { + if v != nil { + t.Errorf("Expected: nil, recieved: %v", v) + } } - _, err = conn.SelectString("select null") - if err == nil { - t.Error("Should have received error on null") + v, err = conn.SelectValue("select 'foo'") + if err != nil { + t.Errorf("Unable to select string: %v", err) + } else { + s, ok := v.(string) + if !(ok && s == "foo") { + t.Errorf("Expected: foo, recieved: %#v", v) + } + } + + v, err = conn.SelectValue("select true") + if err != nil { + t.Errorf("Unable to select bool: %#v", err) + } else { + s, ok := v.(bool) + if !(ok && s == true) { + t.Errorf("Expected true, received: %#v", v) + } + } + + v, err = conn.SelectValue("select false") + if err != nil { + t.Errorf("Unable to select bool: %v", err) + } else { + s, ok := v.(bool) + if !(ok && s == false) { + t.Errorf("Expected false, received: %#v", v) + } + } + + v, err = conn.SelectValue("select 1::int2") + if err != nil { + t.Errorf("Unable to select int2: %v", err) + } else { + s, ok := v.(int16) + if !(ok && s == 1) { + t.Errorf("Expected 1, received: %#v", v) + } + } + + v, err = conn.SelectValue("select 1::int4") + if err != nil { + t.Errorf("Unable to select int4: %v", err) + } else { + s, ok := v.(int32) + if !(ok && s == 1) { + t.Errorf("Expected 1, received: %#v", v) + } + } + + v, err = conn.SelectValue("select 1::int8") + if err != nil { + t.Errorf("Unable to select int8: %#v", err) + } else { + s, ok := v.(int64) + if !(ok && s == 1) { + t.Errorf("Expected 1, received: %#v", v) + } + } + + v, err = conn.SelectValue("select 1.23::float4") + if err != nil { + t.Errorf("Unable to select float4: %#v", err) + } else { + s, ok := v.(float32) + if !(ok && s == float32(1.23)) { + t.Errorf("Expected 1.23, received: %#v", v) + } } } - - -func TestSelectInt64(t *testing.T) { - conn := getSharedConnection() - - i, err := conn.SelectInt64("select 1") - if err != nil { - t.Fatal("Unable to select int64: " + err.Error()) - } - - if i != 1 { - t.Error("Received incorrect int64") - } - - i, err = conn.SelectInt64("select power(2,65)::numeric") - if err == nil || !strings.Contains(err.Error(), "value out of range") { - t.Error("Expected value out of range error when selecting number greater than max int64") - } - - i, err = conn.SelectInt64("select -power(2,65)::numeric") - if err == nil || !strings.Contains(err.Error(), "value out of range") { - t.Error("Expected value out of range error when selecting number less than min int64") - } - - _, err = conn.SelectInt64("select null") - if err == nil || !strings.Contains(err.Error(), "NULL") { - t.Error("Should have received error on null") - } -} - -func TestSelectInt32(t *testing.T) { - conn := getSharedConnection() - - i, err := conn.SelectInt32("select 1") - if err != nil { - t.Fatal("Unable to select int32: " + err.Error()) - } - - if i != 1 { - t.Error("Received incorrect int32") - } - - i, err = conn.SelectInt32("select power(2,33)::numeric") - if err == nil || !strings.Contains(err.Error(), "value out of range") { - t.Error("Expected value out of range error when selecting number greater than max int32") - } - - i, err = conn.SelectInt32("select -power(2,33)::numeric") - if err == nil || !strings.Contains(err.Error(), "value out of range") { - t.Error("Expected value out of range error when selecting number less than min int32") - } - - _, err = conn.SelectInt32("select null") - if err == nil || !strings.Contains(err.Error(), "NULL") { - t.Error("Should have received error on null") - } -} - -func TestSelectInt16(t *testing.T) { - conn := getSharedConnection() - - i, err := conn.SelectInt16("select 1") - if err != nil { - t.Fatal("Unable to select int16: " + err.Error()) - } - - if i != 1 { - t.Error("Received incorrect int16") - } - - i, err = conn.SelectInt16("select power(2,17)::numeric") - if err == nil || !strings.Contains(err.Error(), "value out of range") { - t.Error("Expected value out of range error when selecting number greater than max int16") - } - - i, err = conn.SelectInt16("select -power(2,17)::numeric") - if err == nil || !strings.Contains(err.Error(), "value out of range") { - t.Error("Expected value out of range error when selecting number less than min int16") - } - - _, err = conn.SelectInt16("select null") - if err == nil || !strings.Contains(err.Error(), "NULL") { - t.Error("Should have received error on null") - } -} - - - -func TestSelectFloat64(t *testing.T) { - conn := getSharedConnection() - - f, err := conn.SelectFloat64("select 1.23") - if err != nil { - t.Fatal("Unable to select float64: " + err.Error()) - } - - if f != 1.23 { - t.Error("Received incorrect float64") - } - - _, err = conn.SelectFloat64("select null") - if err == nil || !strings.Contains(err.Error(), "NULL") { - t.Error("Should have received error on null") - } -} - -func TestSelectFloat32(t *testing.T) { - conn := getSharedConnection() - - f, err := conn.SelectFloat32("select 1.23") - if err != nil { - t.Fatal("Unable to select float32: " + err.Error()) - } - - if f != 1.23 { - t.Error("Received incorrect float32") - } - - _, err = conn.SelectFloat32("select null") - if err == nil || !strings.Contains(err.Error(), "NULL") { - t.Error("Should have received error on null") - } -} - diff --git a/connection_select_value_test.go.erb b/connection_select_value_test.go.erb deleted file mode 100644 index 20f8f1de..00000000 --- a/connection_select_value_test.go.erb +++ /dev/null @@ -1,72 +0,0 @@ -package pgx - -import ( - "strings" - "testing" -) - -func TestSelectString(t *testing.T) { - conn := getSharedConnection() - - s, err := conn.SelectString("select 'foo'") - if err != nil { - t.Error("Unable to select string: " + err.Error()) - } else if s != "foo" { - t.Error("Received incorrect string") - } - - _, err = conn.SelectString("select null") - if err == nil { - t.Error("Should have received error on null") - } -} - -<% [64, 32, 16].each do |size| %> -func TestSelectInt<%= size %>(t *testing.T) { - conn := getSharedConnection() - - i, err := conn.SelectInt<%= size %>("select 1") - if err != nil { - t.Fatal("Unable to select int<%= size %>: " + err.Error()) - } - - if i != 1 { - t.Error("Received incorrect int<%= size %>") - } - - i, err = conn.SelectInt<%= size %>("select power(2,<%= size + 1 %>)::numeric") - if err == nil || !strings.Contains(err.Error(), "value out of range") { - t.Error("Expected value out of range error when selecting number greater than max int<%= size %>") - } - - i, err = conn.SelectInt<%= size %>("select -power(2,<%= size + 1 %>)::numeric") - if err == nil || !strings.Contains(err.Error(), "value out of range") { - t.Error("Expected value out of range error when selecting number less than min int<%= size %>") - } - - _, err = conn.SelectInt<%= size %>("select null") - if err == nil || !strings.Contains(err.Error(), "NULL") { - t.Error("Should have received error on null") - } -} -<% end %> - -<% [64, 32].each do |size| %> -func TestSelectFloat<%= size %>(t *testing.T) { - conn := getSharedConnection() - - f, err := conn.SelectFloat<%= size %>("select 1.23") - if err != nil { - t.Fatal("Unable to select float<%= size %>: " + err.Error()) - } - - if f != 1.23 { - t.Error("Received incorrect float<%= size %>") - } - - _, err = conn.SelectFloat<%= size %>("select null") - if err == nil || !strings.Contains(err.Error(), "NULL") { - t.Error("Should have received error on null") - } -} -<% end %>