Add CollectRows and RowTo* functions
Collect functionality was originally developed in pgxutil
This commit is contained in:
@@ -395,3 +395,124 @@ func ForEachScannedRow(rows Rows, scans []any, fn func() error) (pgconn.CommandT
|
||||
|
||||
return rows.CommandTag(), nil
|
||||
}
|
||||
|
||||
// CollectableRow is the subset of Rows methods that a RowToFunc is allowed to call.
|
||||
type CollectableRow interface {
|
||||
FieldDescriptions() []pgproto3.FieldDescription
|
||||
Scan(dest ...any) error
|
||||
Values() ([]any, error)
|
||||
RawValues() [][]byte
|
||||
}
|
||||
|
||||
// RowToFunc is a function that scans or otherwise converts row to a T.
|
||||
type RowToFunc[T any] func(row CollectableRow) (T, error)
|
||||
|
||||
// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T.
|
||||
func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
|
||||
defer rows.Close()
|
||||
|
||||
slice := []T{}
|
||||
|
||||
for rows.Next() {
|
||||
value, err := fn(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
slice = append(slice, value)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
// RowTo returns a T scanned from row.
|
||||
func RowTo[T any](row CollectableRow) (T, error) {
|
||||
var value T
|
||||
err := row.Scan(&value)
|
||||
return value, err
|
||||
}
|
||||
|
||||
// RowTo returns a the address of a T scanned from row.
|
||||
func RowToAddrOf[T any](row CollectableRow) (*T, error) {
|
||||
var value T
|
||||
err := row.Scan(&value)
|
||||
return &value, err
|
||||
}
|
||||
|
||||
// RowToMap returns a map scanned from row.
|
||||
func RowToMap(row CollectableRow) (map[string]any, error) {
|
||||
var value map[string]any
|
||||
err := row.Scan((*mapRowScanner)(&value))
|
||||
return value, err
|
||||
}
|
||||
|
||||
type mapRowScanner map[string]any
|
||||
|
||||
func (rs *mapRowScanner) ScanRow(rows Rows) error {
|
||||
values, err := rows.Values()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*rs = make(mapRowScanner, len(values))
|
||||
|
||||
for i := range values {
|
||||
(*rs)[string(rows.FieldDescriptions()[i].Name)] = values[i]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number a public fields as row
|
||||
// has fields. The row and T fields will by matched by position.
|
||||
func RowToStructByPos[T any](row CollectableRow) (T, error) {
|
||||
var value T
|
||||
err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value})
|
||||
return value, err
|
||||
}
|
||||
|
||||
// RowToAddrOfStructByPos returns the address of a T scanned from row. T must be a struct. T must have the same number a
|
||||
// public fields as row has fields. The row and T fields will by matched by position.
|
||||
func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) {
|
||||
var value T
|
||||
err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value})
|
||||
return &value, err
|
||||
}
|
||||
|
||||
type positionalStructRowScanner struct {
|
||||
ptrToStruct any
|
||||
}
|
||||
|
||||
func (rs *positionalStructRowScanner) ScanRow(rows Rows) error {
|
||||
dst := rs.ptrToStruct
|
||||
dstValue := reflect.ValueOf(dst)
|
||||
if dstValue.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("dst not a pointer")
|
||||
}
|
||||
|
||||
dstElemValue := dstValue.Elem()
|
||||
dstElemType := dstElemValue.Type()
|
||||
|
||||
exportedFields := make([]int, 0, dstElemType.NumField())
|
||||
for i := 0; i < dstElemType.NumField(); i++ {
|
||||
sf := dstElemType.Field(i)
|
||||
if sf.PkgPath == "" {
|
||||
exportedFields = append(exportedFields, i)
|
||||
}
|
||||
}
|
||||
|
||||
rowFieldCount := len(rows.RawValues())
|
||||
if rowFieldCount > len(exportedFields) {
|
||||
return fmt.Errorf("got %d values, but dst struct has only %d fields", rowFieldCount, len(exportedFields))
|
||||
}
|
||||
|
||||
scanTargets := make([]any, rowFieldCount)
|
||||
for i := 0; i < rowFieldCount; i++ {
|
||||
scanTargets[i] = dstElemValue.Field(exportedFields[i]).Addr().Interface()
|
||||
}
|
||||
|
||||
return rows.Scan(scanTargets...)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user