Add database/sql support to pgtype
This commit is contained in:
+51
-15
@@ -68,14 +68,17 @@ func init() {
|
||||
databaseSqlOids = make(map[pgtype.Oid]bool)
|
||||
databaseSqlOids[pgtype.BoolOid] = true
|
||||
databaseSqlOids[pgtype.ByteaOid] = true
|
||||
databaseSqlOids[pgtype.CidOid] = true
|
||||
databaseSqlOids[pgtype.DateOid] = true
|
||||
databaseSqlOids[pgtype.Float4Oid] = true
|
||||
databaseSqlOids[pgtype.Float8Oid] = true
|
||||
databaseSqlOids[pgtype.Int2Oid] = true
|
||||
databaseSqlOids[pgtype.Int4Oid] = true
|
||||
databaseSqlOids[pgtype.Int8Oid] = true
|
||||
databaseSqlOids[pgtype.Float4Oid] = true
|
||||
databaseSqlOids[pgtype.Float8Oid] = true
|
||||
databaseSqlOids[pgtype.DateOid] = true
|
||||
databaseSqlOids[pgtype.TimestamptzOid] = true
|
||||
databaseSqlOids[pgtype.OidOid] = true
|
||||
databaseSqlOids[pgtype.TimestampOid] = true
|
||||
databaseSqlOids[pgtype.TimestamptzOid] = true
|
||||
databaseSqlOids[pgtype.XidOid] = true
|
||||
}
|
||||
|
||||
type Driver struct {
|
||||
@@ -292,9 +295,9 @@ func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
|
||||
return s.conn.queryPrepared(s.ps.Name, argsV)
|
||||
}
|
||||
|
||||
// TODO - rename to avoid alloc
|
||||
type Rows struct {
|
||||
rows *pgx.Rows
|
||||
rows *pgx.Rows
|
||||
values []interface{}
|
||||
}
|
||||
|
||||
func (r *Rows) Columns() []string {
|
||||
@@ -312,6 +315,42 @@ func (r *Rows) Close() error {
|
||||
}
|
||||
|
||||
func (r *Rows) Next(dest []driver.Value) error {
|
||||
if r.values == nil {
|
||||
r.values = make([]interface{}, len(r.rows.FieldDescriptions()))
|
||||
for i, fd := range r.rows.FieldDescriptions() {
|
||||
switch fd.DataType {
|
||||
case pgtype.BoolOid:
|
||||
r.values[i] = &pgtype.Bool{}
|
||||
case pgtype.ByteaOid:
|
||||
r.values[i] = &pgtype.Bytea{}
|
||||
case pgtype.CidOid:
|
||||
r.values[i] = &pgtype.Cid{}
|
||||
case pgtype.DateOid:
|
||||
r.values[i] = &pgtype.Date{}
|
||||
case pgtype.Float4Oid:
|
||||
r.values[i] = &pgtype.Float4{}
|
||||
case pgtype.Float8Oid:
|
||||
r.values[i] = &pgtype.Float8{}
|
||||
case pgtype.Int2Oid:
|
||||
r.values[i] = &pgtype.Int2{}
|
||||
case pgtype.Int4Oid:
|
||||
r.values[i] = &pgtype.Int4{}
|
||||
case pgtype.Int8Oid:
|
||||
r.values[i] = &pgtype.Int8{}
|
||||
case pgtype.OidOid:
|
||||
r.values[i] = &pgtype.OidValue{}
|
||||
case pgtype.TimestampOid:
|
||||
r.values[i] = &pgtype.Timestamp{}
|
||||
case pgtype.TimestamptzOid:
|
||||
r.values[i] = &pgtype.Timestamptz{}
|
||||
case pgtype.XidOid:
|
||||
r.values[i] = &pgtype.Xid{}
|
||||
default:
|
||||
r.values[i] = &pgtype.GenericText{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
more := r.rows.Next()
|
||||
if !more {
|
||||
if r.rows.Err() == nil {
|
||||
@@ -321,19 +360,16 @@ func (r *Rows) Next(dest []driver.Value) error {
|
||||
}
|
||||
}
|
||||
|
||||
values, err := r.rows.ValuesForStdlib()
|
||||
err := r.rows.Scan(r.values...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(dest) < len(values) {
|
||||
fmt.Printf("%d: %#v\n", len(dest), dest)
|
||||
fmt.Printf("%d: %#v\n", len(values), values)
|
||||
return errors.New("expected more values than were received")
|
||||
}
|
||||
|
||||
for i, v := range values {
|
||||
dest[i] = driver.Value(v)
|
||||
for i, v := range r.values {
|
||||
dest[i], err = v.(driver.Valuer).Value()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user