2
0

Replace RowsFromResultReader with ScanRow function

This commit is contained in:
Jack Christensen
2019-09-10 18:34:05 -05:00
parent 76348773bd
commit 7d053e4d5c
2 changed files with 36 additions and 41 deletions
+8 -15
View File
@@ -1367,7 +1367,7 @@ func TestQueryCloseBefore(t *testing.T) {
assert.True(t, pgconn.SafeToRetry(err)) assert.True(t, pgconn.SafeToRetry(err))
} }
func TestRowsFromResultReader(t *testing.T) { func TestScanRow(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
@@ -1377,26 +1377,19 @@ func TestRowsFromResultReader(t *testing.T) {
var sum, rowCount int32 var sum, rowCount int32
rows := pgx.RowsFromResultReader(conn.ConnInfo(), resultReader) for resultReader.NextRow() {
defer rows.Close()
for rows.Next() {
var n int32 var n int32
rows.Scan(&n) err := pgx.ScanRow(conn.ConnInfo(), resultReader.FieldDescriptions(), resultReader.Values(), &n)
assert.NoError(t, err)
sum += n sum += n
rowCount++ rowCount++
} }
if rows.Err() != nil { _, err := resultReader.Close()
t.Fatalf("conn.Query failed: %v", rows.Err())
}
if rowCount != 10 { require.NoError(t, err)
t.Error("wrong number of rows") assert.EqualValues(t, 10, rowCount)
} assert.EqualValues(t, 55, sum)
if sum != 55 {
t.Error("Wrong values returned")
}
} }
func TestConnSimpleProtocol(t *testing.T) { func TestConnSimpleProtocol(t *testing.T) {
+28 -26
View File
@@ -174,27 +174,12 @@ func (rows *connRows) Next() bool {
} }
func (rows *connRows) Scan(dest ...interface{}) error { func (rows *connRows) Scan(dest ...interface{}) error {
if len(rows.FieldDescriptions()) != len(dest) { err := ScanRow(rows.connInfo, rows.FieldDescriptions(), rows.values, dest...)
err := errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.FieldDescriptions())) if err != nil {
rows.fatal(err) rows.fatal(err)
return err return err
} }
for i, d := range dest {
buf := rows.values[i]
fd := &rows.FieldDescriptions()[i]
if d == nil {
continue
}
err := rows.connInfo.Scan(fd.DataTypeOID, fd.Format, buf, d)
if err != nil {
rows.fatal(scanArgError{col: i, err: err})
return err
}
}
return nil return nil
} }
@@ -254,7 +239,7 @@ func (rows *connRows) Values() ([]interface{}, error) {
} }
func (rows *connRows) RawValues() [][]byte { func (rows *connRows) RawValues() [][]byte {
return rows.resultReader.Values() return rows.values
} }
type scanArgError struct { type scanArgError struct {
@@ -266,14 +251,31 @@ func (e scanArgError) Error() string {
return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err) return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err)
} }
// RowsFromResultReader wraps a *pgconn.ResultReader in a Rows wrapper so a more convenient scanning interface can be // ScanRow decodes raw row data into dest. This is a low level function used internally to to implement the Rows
// used. // interface Scan method. It can be used to scan rows read from the lower level pgconn interface.
// //
// In most cases, the appropriate pgx query methods should be used instead of sending a query with pgconn and reading // connInfo - OID to Go type mapping.
// the results with pgx. // fieldDescriptions - OID and format of values
func RowsFromResultReader(connInfo *pgtype.ConnInfo, rr *pgconn.ResultReader) Rows { // values - the raw data as returned from the PostgreSQL server
return &connRows{ // dest - the destination that values will be decoded into
connInfo: connInfo, func ScanRow(connInfo *pgtype.ConnInfo, fieldDescriptions []pgproto3.FieldDescription, values [][]byte, dest ...interface{}) error {
resultReader: rr, if len(fieldDescriptions) != len(values) {
return errors.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values))
} }
if len(fieldDescriptions) != len(dest) {
return errors.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest))
}
for i, d := range dest {
if d == nil {
continue
}
err := connInfo.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d)
if err != nil {
return scanArgError{col: i, err: err}
}
}
return nil
} }