Add initial database/sql support
This commit is contained in:
+171
@@ -0,0 +1,171 @@
|
|||||||
|
package stdlib
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
"github.com/JackC/pgx"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
d := &Driver{}
|
||||||
|
sql.Register("pgx", d)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Driver struct{}
|
||||||
|
|
||||||
|
func (d *Driver) Open(name string) (driver.Conn, error) {
|
||||||
|
connConfig, err := pgx.ParseURI(name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := pgx.Connect(connConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c := &Conn{conn: conn}
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Conn struct {
|
||||||
|
conn *pgx.Conn
|
||||||
|
psCount int64 // Counter used for creating unique prepared statement names
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Prepare(query string) (driver.Stmt, error) {
|
||||||
|
name := fmt.Sprintf("pgx_%d", c.psCount)
|
||||||
|
c.psCount++
|
||||||
|
|
||||||
|
ps, err := c.conn.Prepare(name, query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Stmt{ps: ps, conn: c.conn}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Close() error {
|
||||||
|
return c.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Begin() (driver.Tx, error) {
|
||||||
|
_, err := c.conn.Execute("begin")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Tx{conn: c.conn}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Stmt struct {
|
||||||
|
ps *pgx.PreparedStatement
|
||||||
|
conn *pgx.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stmt) Close() error {
|
||||||
|
return s.conn.Deallocate(s.ps.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stmt) NumInput() int {
|
||||||
|
return len(s.ps.ParameterOids)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) {
|
||||||
|
args := valueToInterface(argsV)
|
||||||
|
commandTag, err := s.conn.Execute(s.ps.Name, args...)
|
||||||
|
return driver.RowsAffected(commandTag.RowsAffected()), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
|
||||||
|
args := valueToInterface(argsV)
|
||||||
|
|
||||||
|
rowCount := 0
|
||||||
|
columnsChan := make(chan []string)
|
||||||
|
errChan := make(chan error)
|
||||||
|
rowChan := make(chan []driver.Value)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
err := s.conn.SelectFunc(s.ps.Name, func(r *pgx.DataRowReader) error {
|
||||||
|
if rowCount == 0 {
|
||||||
|
fieldNames := make([]string, len(r.FieldDescriptions))
|
||||||
|
for i, fd := range r.FieldDescriptions {
|
||||||
|
fieldNames[i] = fd.Name
|
||||||
|
}
|
||||||
|
columnsChan <- fieldNames
|
||||||
|
}
|
||||||
|
rowCount++
|
||||||
|
|
||||||
|
values := make([]driver.Value, len(r.FieldDescriptions))
|
||||||
|
for i, _ := range r.FieldDescriptions {
|
||||||
|
values[i] = r.ReadValue()
|
||||||
|
}
|
||||||
|
rowChan <- values
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}, args...)
|
||||||
|
close(rowChan)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
rows := Rows{rowChan: rowChan}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case rows.columnNames = <-columnsChan:
|
||||||
|
return &rows, nil
|
||||||
|
case err := <-errChan:
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Rows struct {
|
||||||
|
columnNames []string
|
||||||
|
rowChan chan []driver.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Rows) Columns() []string {
|
||||||
|
return r.columnNames
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Rows) Close() error {
|
||||||
|
for _ = range r.rowChan {
|
||||||
|
// Ensure all rows are read
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Rows) Next(dest []driver.Value) error {
|
||||||
|
row, ok := <-r.rowChan
|
||||||
|
if !ok {
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(dest, row)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func valueToInterface(argsV []driver.Value) []interface{} {
|
||||||
|
args := make([]interface{}, 0, len(argsV))
|
||||||
|
for _, v := range argsV {
|
||||||
|
args = append(args, v.(interface{}))
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tx struct {
|
||||||
|
conn *pgx.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Tx) Commit() error {
|
||||||
|
_, err := t.conn.Execute("commit")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Tx) Rollback() error {
|
||||||
|
_, err := t.conn.Execute("rollback")
|
||||||
|
return err
|
||||||
|
}
|
||||||
@@ -0,0 +1,233 @@
|
|||||||
|
package stdlib_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
_ "github.com/JackC/pgx/stdlib"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalLifeCycle(t *testing.T) {
|
||||||
|
db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sql.Open failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := db.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("db.Close unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
stmt, err := db.Prepare("select 'foo', n from generate_series($1::int, $2::int) n")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("db.Prepare unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = stmt.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("stmt.Close unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
rows, err := stmt.Query(int32(1), int32(10))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rowCount := int64(0)
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
rowCount++
|
||||||
|
|
||||||
|
var s string
|
||||||
|
var n int64
|
||||||
|
if err := rows.Scan(&s, &n); err != nil {
|
||||||
|
t.Fatalf("rows.Scan unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
if s != "foo" {
|
||||||
|
t.Errorf(`Expected "foo", received "%v"`, s)
|
||||||
|
}
|
||||||
|
if n != rowCount {
|
||||||
|
t.Errorf("Expected %d, received %d", rowCount, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = rows.Err()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("rows.Err unexpectedly is: %v", err)
|
||||||
|
}
|
||||||
|
if rowCount != 10 {
|
||||||
|
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rows.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("rows.Close unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryCloseRowsEarly(t *testing.T) {
|
||||||
|
db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sql.Open failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := db.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("db.Close unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
stmt, err := db.Prepare("select 'foo', n from generate_series($1::int, $2::int) n")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("db.Prepare unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = stmt.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("stmt.Close unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
rows, err := stmt.Query(int32(1), int32(10))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close rows immediately without having read them
|
||||||
|
err = rows.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("rows.Close unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the query again to ensure the connection and statement are still ok
|
||||||
|
rows, err = stmt.Query(int32(1), int32(10))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rowCount := int64(0)
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
rowCount++
|
||||||
|
|
||||||
|
var s string
|
||||||
|
var n int64
|
||||||
|
if err := rows.Scan(&s, &n); err != nil {
|
||||||
|
t.Fatalf("rows.Scan unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
if s != "foo" {
|
||||||
|
t.Errorf(`Expected "foo", received "%v"`, s)
|
||||||
|
}
|
||||||
|
if n != rowCount {
|
||||||
|
t.Errorf("Expected %d, received %d", rowCount, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = rows.Err()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("rows.Err unexpectedly is: %v", err)
|
||||||
|
}
|
||||||
|
if rowCount != 10 {
|
||||||
|
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rows.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("rows.Close unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExec(t *testing.T) {
|
||||||
|
db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sql.Open failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := db.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("db.Close unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err = db.Exec("create temporary table t(a varchar not null)")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("db.Exec unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := db.Exec("insert into t values('hey')")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("db.Exec unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("result.RowsAffected unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
if n != 1 {
|
||||||
|
t.Fatalf("Expected 1, received %d", n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransactionLifeCycle(t *testing.T) {
|
||||||
|
db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sql.Open failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := db.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("db.Close unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err = db.Exec("create temporary table t(a varchar not null)")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("db.Exec unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("db.Begin unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.Exec("insert into t values('hi')")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("tx.Exec unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Rollback()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("tx.Rollback unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
err = db.QueryRow("select count(*) from t").Scan(&n)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("db.QueryRow.Scan unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
if n != 0 {
|
||||||
|
t.Fatalf("Expected 0 rows due to rollback, instead found %d", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err = db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("db.Begin unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.Exec("insert into t values('hi')")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("tx.Exec unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("tx.Commit unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.QueryRow("select count(*) from t").Scan(&n)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("db.QueryRow.Scan unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
if n != 1 {
|
||||||
|
t.Fatalf("Expected 1 rows due to rollback, instead found %d", n)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user