Replace RowsFromResultReader with ScanRow function
This commit is contained in:
+8
-15
@@ -1367,7 +1367,7 @@ func TestQueryCloseBefore(t *testing.T) {
|
|||||||
assert.True(t, pgconn.SafeToRetry(err))
|
assert.True(t, pgconn.SafeToRetry(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRowsFromResultReader(t *testing.T) {
|
func TestScanRow(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
@@ -1377,26 +1377,19 @@ func TestRowsFromResultReader(t *testing.T) {
|
|||||||
|
|
||||||
var sum, rowCount int32
|
var sum, rowCount int32
|
||||||
|
|
||||||
rows := pgx.RowsFromResultReader(conn.ConnInfo(), resultReader)
|
for resultReader.NextRow() {
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
var n int32
|
var n int32
|
||||||
rows.Scan(&n)
|
err := pgx.ScanRow(conn.ConnInfo(), resultReader.FieldDescriptions(), resultReader.Values(), &n)
|
||||||
|
assert.NoError(t, err)
|
||||||
sum += n
|
sum += n
|
||||||
rowCount++
|
rowCount++
|
||||||
}
|
}
|
||||||
|
|
||||||
if rows.Err() != nil {
|
_, err := resultReader.Close()
|
||||||
t.Fatalf("conn.Query failed: %v", rows.Err())
|
|
||||||
}
|
|
||||||
|
|
||||||
if rowCount != 10 {
|
require.NoError(t, err)
|
||||||
t.Error("wrong number of rows")
|
assert.EqualValues(t, 10, rowCount)
|
||||||
}
|
assert.EqualValues(t, 55, sum)
|
||||||
if sum != 55 {
|
|
||||||
t.Error("Wrong values returned")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnSimpleProtocol(t *testing.T) {
|
func TestConnSimpleProtocol(t *testing.T) {
|
||||||
|
|||||||
@@ -174,27 +174,12 @@ func (rows *connRows) Next() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (rows *connRows) Scan(dest ...interface{}) error {
|
func (rows *connRows) Scan(dest ...interface{}) error {
|
||||||
if len(rows.FieldDescriptions()) != len(dest) {
|
err := ScanRow(rows.connInfo, rows.FieldDescriptions(), rows.values, dest...)
|
||||||
err := errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.FieldDescriptions()))
|
if err != nil {
|
||||||
rows.fatal(err)
|
rows.fatal(err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, d := range dest {
|
|
||||||
buf := rows.values[i]
|
|
||||||
fd := &rows.FieldDescriptions()[i]
|
|
||||||
|
|
||||||
if d == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err := rows.connInfo.Scan(fd.DataTypeOID, fd.Format, buf, d)
|
|
||||||
if err != nil {
|
|
||||||
rows.fatal(scanArgError{col: i, err: err})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -254,7 +239,7 @@ func (rows *connRows) Values() ([]interface{}, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (rows *connRows) RawValues() [][]byte {
|
func (rows *connRows) RawValues() [][]byte {
|
||||||
return rows.resultReader.Values()
|
return rows.values
|
||||||
}
|
}
|
||||||
|
|
||||||
type scanArgError struct {
|
type scanArgError struct {
|
||||||
@@ -266,14 +251,31 @@ func (e scanArgError) Error() string {
|
|||||||
return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err)
|
return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RowsFromResultReader wraps a *pgconn.ResultReader in a Rows wrapper so a more convenient scanning interface can be
|
// ScanRow decodes raw row data into dest. This is a low level function used internally to to implement the Rows
|
||||||
// used.
|
// interface Scan method. It can be used to scan rows read from the lower level pgconn interface.
|
||||||
//
|
//
|
||||||
// In most cases, the appropriate pgx query methods should be used instead of sending a query with pgconn and reading
|
// connInfo - OID to Go type mapping.
|
||||||
// the results with pgx.
|
// fieldDescriptions - OID and format of values
|
||||||
func RowsFromResultReader(connInfo *pgtype.ConnInfo, rr *pgconn.ResultReader) Rows {
|
// values - the raw data as returned from the PostgreSQL server
|
||||||
return &connRows{
|
// dest - the destination that values will be decoded into
|
||||||
connInfo: connInfo,
|
func ScanRow(connInfo *pgtype.ConnInfo, fieldDescriptions []pgproto3.FieldDescription, values [][]byte, dest ...interface{}) error {
|
||||||
resultReader: rr,
|
if len(fieldDescriptions) != len(values) {
|
||||||
|
return errors.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values))
|
||||||
}
|
}
|
||||||
|
if len(fieldDescriptions) != len(dest) {
|
||||||
|
return errors.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, d := range dest {
|
||||||
|
if d == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := connInfo.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d)
|
||||||
|
if err != nil {
|
||||||
|
return scanArgError{col: i, err: err}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user