2
0

Add database/sql support to pgtype

This commit is contained in:
Jack Christensen
2017-03-18 21:11:43 -05:00
parent 5572c002dc
commit bec9bd261b
55 changed files with 1459 additions and 201 deletions
+51 -15
View File
@@ -68,14 +68,17 @@ func init() {
databaseSqlOids = make(map[pgtype.Oid]bool)
databaseSqlOids[pgtype.BoolOid] = true
databaseSqlOids[pgtype.ByteaOid] = true
databaseSqlOids[pgtype.CidOid] = true
databaseSqlOids[pgtype.DateOid] = true
databaseSqlOids[pgtype.Float4Oid] = true
databaseSqlOids[pgtype.Float8Oid] = true
databaseSqlOids[pgtype.Int2Oid] = true
databaseSqlOids[pgtype.Int4Oid] = true
databaseSqlOids[pgtype.Int8Oid] = true
databaseSqlOids[pgtype.Float4Oid] = true
databaseSqlOids[pgtype.Float8Oid] = true
databaseSqlOids[pgtype.DateOid] = true
databaseSqlOids[pgtype.TimestamptzOid] = true
databaseSqlOids[pgtype.OidOid] = true
databaseSqlOids[pgtype.TimestampOid] = true
databaseSqlOids[pgtype.TimestamptzOid] = true
databaseSqlOids[pgtype.XidOid] = true
}
type Driver struct {
@@ -292,9 +295,9 @@ func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
return s.conn.queryPrepared(s.ps.Name, argsV)
}
// TODO - rename to avoid alloc
type Rows struct {
rows *pgx.Rows
rows *pgx.Rows
values []interface{}
}
func (r *Rows) Columns() []string {
@@ -312,6 +315,42 @@ func (r *Rows) Close() error {
}
func (r *Rows) Next(dest []driver.Value) error {
if r.values == nil {
r.values = make([]interface{}, len(r.rows.FieldDescriptions()))
for i, fd := range r.rows.FieldDescriptions() {
switch fd.DataType {
case pgtype.BoolOid:
r.values[i] = &pgtype.Bool{}
case pgtype.ByteaOid:
r.values[i] = &pgtype.Bytea{}
case pgtype.CidOid:
r.values[i] = &pgtype.Cid{}
case pgtype.DateOid:
r.values[i] = &pgtype.Date{}
case pgtype.Float4Oid:
r.values[i] = &pgtype.Float4{}
case pgtype.Float8Oid:
r.values[i] = &pgtype.Float8{}
case pgtype.Int2Oid:
r.values[i] = &pgtype.Int2{}
case pgtype.Int4Oid:
r.values[i] = &pgtype.Int4{}
case pgtype.Int8Oid:
r.values[i] = &pgtype.Int8{}
case pgtype.OidOid:
r.values[i] = &pgtype.OidValue{}
case pgtype.TimestampOid:
r.values[i] = &pgtype.Timestamp{}
case pgtype.TimestamptzOid:
r.values[i] = &pgtype.Timestamptz{}
case pgtype.XidOid:
r.values[i] = &pgtype.Xid{}
default:
r.values[i] = &pgtype.GenericText{}
}
}
}
more := r.rows.Next()
if !more {
if r.rows.Err() == nil {
@@ -321,19 +360,16 @@ func (r *Rows) Next(dest []driver.Value) error {
}
}
values, err := r.rows.ValuesForStdlib()
err := r.rows.Scan(r.values...)
if err != nil {
return err
}
if len(dest) < len(values) {
fmt.Printf("%d: %#v\n", len(dest), dest)
fmt.Printf("%d: %#v\n", len(values), values)
return errors.New("expected more values than were received")
}
for i, v := range values {
dest[i] = driver.Value(v)
for i, v := range r.values {
dest[i], err = v.(driver.Valuer).Value()
if err != nil {
return err
}
}
return nil