diff --git a/query_test.go b/query_test.go index 1e075fa6..11c5d910 100644 --- a/query_test.go +++ b/query_test.go @@ -1367,7 +1367,7 @@ func TestQueryCloseBefore(t *testing.T) { assert.True(t, pgconn.SafeToRetry(err)) } -func TestRowsFromResultReader(t *testing.T) { +func TestScanRow(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) @@ -1377,26 +1377,19 @@ func TestRowsFromResultReader(t *testing.T) { var sum, rowCount int32 - rows := pgx.RowsFromResultReader(conn.ConnInfo(), resultReader) - defer rows.Close() - - for rows.Next() { + for resultReader.NextRow() { var n int32 - rows.Scan(&n) + err := pgx.ScanRow(conn.ConnInfo(), resultReader.FieldDescriptions(), resultReader.Values(), &n) + assert.NoError(t, err) sum += n rowCount++ } - if rows.Err() != nil { - t.Fatalf("conn.Query failed: %v", rows.Err()) - } + _, err := resultReader.Close() - if rowCount != 10 { - t.Error("wrong number of rows") - } - if sum != 55 { - t.Error("Wrong values returned") - } + require.NoError(t, err) + assert.EqualValues(t, 10, rowCount) + assert.EqualValues(t, 55, sum) } func TestConnSimpleProtocol(t *testing.T) { diff --git a/rows.go b/rows.go index 7389c56b..d7595895 100644 --- a/rows.go +++ b/rows.go @@ -174,27 +174,12 @@ func (rows *connRows) Next() bool { } func (rows *connRows) Scan(dest ...interface{}) error { - if len(rows.FieldDescriptions()) != len(dest) { - err := errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.FieldDescriptions())) + err := ScanRow(rows.connInfo, rows.FieldDescriptions(), rows.values, dest...) + if err != nil { rows.fatal(err) return err } - for i, d := range dest { - buf := rows.values[i] - fd := &rows.FieldDescriptions()[i] - - if d == nil { - continue - } - - err := rows.connInfo.Scan(fd.DataTypeOID, fd.Format, buf, d) - if err != nil { - rows.fatal(scanArgError{col: i, err: err}) - return err - } - } - return nil } @@ -254,7 +239,7 @@ func (rows *connRows) Values() ([]interface{}, error) { } func (rows *connRows) RawValues() [][]byte { - return rows.resultReader.Values() + return rows.values } type scanArgError struct { @@ -266,14 +251,31 @@ func (e scanArgError) Error() string { return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err) } -// RowsFromResultReader wraps a *pgconn.ResultReader in a Rows wrapper so a more convenient scanning interface can be -// used. +// ScanRow decodes raw row data into dest. This is a low level function used internally to to implement the Rows +// interface Scan method. It can be used to scan rows read from the lower level pgconn interface. // -// In most cases, the appropriate pgx query methods should be used instead of sending a query with pgconn and reading -// the results with pgx. -func RowsFromResultReader(connInfo *pgtype.ConnInfo, rr *pgconn.ResultReader) Rows { - return &connRows{ - connInfo: connInfo, - resultReader: rr, +// connInfo - OID to Go type mapping. +// fieldDescriptions - OID and format of values +// values - the raw data as returned from the PostgreSQL server +// dest - the destination that values will be decoded into +func ScanRow(connInfo *pgtype.ConnInfo, fieldDescriptions []pgproto3.FieldDescription, values [][]byte, dest ...interface{}) error { + if len(fieldDescriptions) != len(values) { + return errors.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values)) } + if len(fieldDescriptions) != len(dest) { + return errors.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest)) + } + + for i, d := range dest { + if d == nil { + continue + } + + err := connInfo.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d) + if err != nil { + return scanArgError{col: i, err: err} + } + } + + return nil }