diff --git a/stdlib/sql.go b/stdlib/sql.go new file mode 100644 index 00000000..b276861a --- /dev/null +++ b/stdlib/sql.go @@ -0,0 +1,171 @@ +package stdlib + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "github.com/JackC/pgx" + "io" +) + +func init() { + d := &Driver{} + sql.Register("pgx", d) +} + +type Driver struct{} + +func (d *Driver) Open(name string) (driver.Conn, error) { + connConfig, err := pgx.ParseURI(name) + if err != nil { + return nil, err + } + + conn, err := pgx.Connect(connConfig) + if err != nil { + return nil, err + } + + c := &Conn{conn: conn} + return c, nil +} + +type Conn struct { + conn *pgx.Conn + psCount int64 // Counter used for creating unique prepared statement names +} + +func (c *Conn) Prepare(query string) (driver.Stmt, error) { + name := fmt.Sprintf("pgx_%d", c.psCount) + c.psCount++ + + ps, err := c.conn.Prepare(name, query) + if err != nil { + return nil, err + } + + return &Stmt{ps: ps, conn: c.conn}, nil +} + +func (c *Conn) Close() error { + return c.conn.Close() +} + +func (c *Conn) Begin() (driver.Tx, error) { + _, err := c.conn.Execute("begin") + if err != nil { + return nil, err + } + + return &Tx{conn: c.conn}, nil +} + +type Stmt struct { + ps *pgx.PreparedStatement + conn *pgx.Conn +} + +func (s *Stmt) Close() error { + return s.conn.Deallocate(s.ps.Name) +} + +func (s *Stmt) NumInput() int { + return len(s.ps.ParameterOids) +} + +func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) { + args := valueToInterface(argsV) + commandTag, err := s.conn.Execute(s.ps.Name, args...) + return driver.RowsAffected(commandTag.RowsAffected()), err +} + +func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { + args := valueToInterface(argsV) + + rowCount := 0 + columnsChan := make(chan []string) + errChan := make(chan error) + rowChan := make(chan []driver.Value) + + go func() { + err := s.conn.SelectFunc(s.ps.Name, func(r *pgx.DataRowReader) error { + if rowCount == 0 { + fieldNames := make([]string, len(r.FieldDescriptions)) + for i, fd := range r.FieldDescriptions { + fieldNames[i] = fd.Name + } + columnsChan <- fieldNames + } + rowCount++ + + values := make([]driver.Value, len(r.FieldDescriptions)) + for i, _ := range r.FieldDescriptions { + values[i] = r.ReadValue() + } + rowChan <- values + + return nil + }, args...) + close(rowChan) + if err != nil { + errChan <- err + } + }() + + rows := Rows{rowChan: rowChan} + + select { + case rows.columnNames = <-columnsChan: + return &rows, nil + case err := <-errChan: + return nil, err + } +} + +type Rows struct { + columnNames []string + rowChan chan []driver.Value +} + +func (r *Rows) Columns() []string { + return r.columnNames +} + +func (r *Rows) Close() error { + for _ = range r.rowChan { + // Ensure all rows are read + } + return nil +} + +func (r *Rows) Next(dest []driver.Value) error { + row, ok := <-r.rowChan + if !ok { + return io.EOF + } + + copy(dest, row) + return nil +} + +func valueToInterface(argsV []driver.Value) []interface{} { + args := make([]interface{}, 0, len(argsV)) + for _, v := range argsV { + args = append(args, v.(interface{})) + } + return args +} + +type Tx struct { + conn *pgx.Conn +} + +func (t *Tx) Commit() error { + _, err := t.conn.Execute("commit") + return err +} + +func (t *Tx) Rollback() error { + _, err := t.conn.Execute("rollback") + return err +} diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go new file mode 100644 index 00000000..81cf32bc --- /dev/null +++ b/stdlib/sql_test.go @@ -0,0 +1,233 @@ +package stdlib_test + +import ( + "database/sql" + _ "github.com/JackC/pgx/stdlib" + "testing" +) + +func TestNormalLifeCycle(t *testing.T) { + db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test") + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + defer func() { + err := db.Close() + if err != nil { + t.Fatalf("db.Close unexpectedly failed: %v", err) + } + }() + + stmt, err := db.Prepare("select 'foo', n from generate_series($1::int, $2::int) n") + if err != nil { + t.Fatalf("db.Prepare unexpectedly failed: %v", err) + } + defer func() { + err = stmt.Close() + if err != nil { + t.Fatalf("stmt.Close unexpectedly failed: %v", err) + } + }() + + rows, err := stmt.Query(int32(1), int32(10)) + if err != nil { + t.Fatalf("stmt.Query unexpectedly failed: %v", err) + } + + rowCount := int64(0) + + for rows.Next() { + rowCount++ + + var s string + var n int64 + if err := rows.Scan(&s, &n); err != nil { + t.Fatalf("rows.Scan unexpectedly failed: %v", err) + } + if s != "foo" { + t.Errorf(`Expected "foo", received "%v"`, s) + } + if n != rowCount { + t.Errorf("Expected %d, received %d", rowCount, n) + } + } + err = rows.Err() + if err != nil { + t.Fatalf("rows.Err unexpectedly is: %v", err) + } + if rowCount != 10 { + t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount) + } + + err = rows.Close() + if err != nil { + t.Fatalf("rows.Close unexpectedly failed: %v", err) + } +} + +func TestQueryCloseRowsEarly(t *testing.T) { + db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test") + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + defer func() { + err := db.Close() + if err != nil { + t.Fatalf("db.Close unexpectedly failed: %v", err) + } + }() + + stmt, err := db.Prepare("select 'foo', n from generate_series($1::int, $2::int) n") + if err != nil { + t.Fatalf("db.Prepare unexpectedly failed: %v", err) + } + defer func() { + err = stmt.Close() + if err != nil { + t.Fatalf("stmt.Close unexpectedly failed: %v", err) + } + }() + + rows, err := stmt.Query(int32(1), int32(10)) + if err != nil { + t.Fatalf("stmt.Query unexpectedly failed: %v", err) + } + + // Close rows immediately without having read them + err = rows.Close() + if err != nil { + t.Fatalf("rows.Close unexpectedly failed: %v", err) + } + + // Run the query again to ensure the connection and statement are still ok + rows, err = stmt.Query(int32(1), int32(10)) + if err != nil { + t.Fatalf("stmt.Query unexpectedly failed: %v", err) + } + + rowCount := int64(0) + + for rows.Next() { + rowCount++ + + var s string + var n int64 + if err := rows.Scan(&s, &n); err != nil { + t.Fatalf("rows.Scan unexpectedly failed: %v", err) + } + if s != "foo" { + t.Errorf(`Expected "foo", received "%v"`, s) + } + if n != rowCount { + t.Errorf("Expected %d, received %d", rowCount, n) + } + } + err = rows.Err() + if err != nil { + t.Fatalf("rows.Err unexpectedly is: %v", err) + } + if rowCount != 10 { + t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount) + } + + err = rows.Close() + if err != nil { + t.Fatalf("rows.Close unexpectedly failed: %v", err) + } +} + +func TestExec(t *testing.T) { + db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test") + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + defer func() { + err := db.Close() + if err != nil { + t.Fatalf("db.Close unexpectedly failed: %v", err) + } + }() + + _, err = db.Exec("create temporary table t(a varchar not null)") + if err != nil { + t.Fatalf("db.Exec unexpectedly failed: %v", err) + } + + result, err := db.Exec("insert into t values('hey')") + if err != nil { + t.Fatalf("db.Exec unexpectedly failed: %v", err) + } + + n, err := result.RowsAffected() + if err != nil { + t.Fatalf("result.RowsAffected unexpectedly failed: %v", err) + } + if n != 1 { + t.Fatalf("Expected 1, received %d", n) + } +} + +func TestTransactionLifeCycle(t *testing.T) { + db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test") + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + defer func() { + err := db.Close() + if err != nil { + t.Fatalf("db.Close unexpectedly failed: %v", err) + } + }() + + _, err = db.Exec("create temporary table t(a varchar not null)") + if err != nil { + t.Fatalf("db.Exec unexpectedly failed: %v", err) + } + + tx, err := db.Begin() + if err != nil { + t.Fatalf("db.Begin unexpectedly failed: %v", err) + } + + _, err = tx.Exec("insert into t values('hi')") + if err != nil { + t.Fatalf("tx.Exec unexpectedly failed: %v", err) + } + + err = tx.Rollback() + if err != nil { + t.Fatalf("tx.Rollback unexpectedly failed: %v", err) + } + + var n int64 + err = db.QueryRow("select count(*) from t").Scan(&n) + if err != nil { + t.Fatalf("db.QueryRow.Scan unexpectedly failed: %v", err) + } + if n != 0 { + t.Fatalf("Expected 0 rows due to rollback, instead found %d", n) + } + + tx, err = db.Begin() + if err != nil { + t.Fatalf("db.Begin unexpectedly failed: %v", err) + } + + _, err = tx.Exec("insert into t values('hi')") + if err != nil { + t.Fatalf("tx.Exec unexpectedly failed: %v", err) + } + + err = tx.Commit() + if err != nil { + t.Fatalf("tx.Commit unexpectedly failed: %v", err) + } + + err = db.QueryRow("select count(*) from t").Scan(&n) + if err != nil { + t.Fatalf("db.QueryRow.Scan unexpectedly failed: %v", err) + } + if n != 1 { + t.Fatalf("Expected 1 rows due to rollback, instead found %d", n) + } +}