Add RowsFromResultReader
This commit is contained in:
@@ -67,10 +67,15 @@ func (r *connRow) Scan(dest ...interface{}) (err error) {
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
type rowLog interface {
|
||||
shouldLog(lvl LogLevel) bool
|
||||
log(lvl LogLevel, msg string, data map[string]interface{})
|
||||
}
|
||||
|
||||
// connRows implements the Rows interface for Conn.Query.
|
||||
type connRows struct {
|
||||
conn *Conn
|
||||
batch *Batch
|
||||
logger rowLog
|
||||
connInfo *pgtype.ConnInfo
|
||||
values [][]byte
|
||||
fields []FieldDescription
|
||||
rowCount int
|
||||
@@ -81,8 +86,7 @@ type connRows struct {
|
||||
args []interface{}
|
||||
closed bool
|
||||
|
||||
resultReader *pgconn.ResultReader
|
||||
multiResultReader *pgconn.MultiResultReader
|
||||
resultReader *pgconn.ResultReader
|
||||
}
|
||||
|
||||
func (rows *connRows) FieldDescriptions() []FieldDescription {
|
||||
@@ -103,25 +107,16 @@ func (rows *connRows) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
if rows.multiResultReader != nil {
|
||||
closeErr := rows.multiResultReader.Close()
|
||||
if rows.logger != nil {
|
||||
if rows.err == nil {
|
||||
rows.err = closeErr
|
||||
if rows.logger.shouldLog(LogLevelInfo) {
|
||||
endTime := time.Now()
|
||||
rows.logger.log(LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount})
|
||||
}
|
||||
} else if rows.logger.shouldLog(LogLevelError) {
|
||||
rows.logger.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)})
|
||||
}
|
||||
}
|
||||
|
||||
if rows.err == nil {
|
||||
if rows.conn.shouldLog(LogLevelInfo) {
|
||||
endTime := time.Now()
|
||||
rows.conn.log(LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount})
|
||||
}
|
||||
} else if rows.conn.shouldLog(LogLevelError) {
|
||||
rows.conn.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)})
|
||||
}
|
||||
|
||||
if rows.batch != nil && rows.err != nil {
|
||||
rows.batch.die(rows.err)
|
||||
}
|
||||
}
|
||||
|
||||
func (rows *connRows) Err() error {
|
||||
@@ -149,7 +144,7 @@ func (rows *connRows) Next() bool {
|
||||
rrFieldDescriptions := rows.resultReader.FieldDescriptions()
|
||||
rows.fields = make([]FieldDescription, len(rrFieldDescriptions))
|
||||
for i := range rrFieldDescriptions {
|
||||
rows.conn.pgproto3FieldDescriptionToPgxFieldDescription(&rrFieldDescriptions[i], &rows.fields[i])
|
||||
pgproto3FieldDescriptionToPgxFieldDescription(rows.connInfo, &rrFieldDescriptions[i], &rows.fields[i])
|
||||
}
|
||||
}
|
||||
rows.rowCount++
|
||||
@@ -191,7 +186,7 @@ func (rows *connRows) Scan(dest ...interface{}) error {
|
||||
continue
|
||||
}
|
||||
|
||||
err := rows.conn.ConnInfo.Scan(fd.DataType, fd.FormatCode, buf, d)
|
||||
err := rows.connInfo.Scan(fd.DataType, fd.FormatCode, buf, d)
|
||||
if err != nil {
|
||||
rows.fatal(scanArgError{col: i, err: err})
|
||||
return err
|
||||
@@ -216,7 +211,7 @@ func (rows *connRows) Values() ([]interface{}, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
if dt, ok := rows.conn.ConnInfo.DataTypeForOID(fd.DataType); ok {
|
||||
if dt, ok := rows.connInfo.DataTypeForOID(fd.DataType); ok {
|
||||
value := reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(pgtype.Value)
|
||||
|
||||
switch fd.FormatCode {
|
||||
@@ -225,7 +220,7 @@ func (rows *connRows) Values() ([]interface{}, error) {
|
||||
if decoder == nil {
|
||||
decoder = &pgtype.GenericText{}
|
||||
}
|
||||
err := decoder.DecodeText(rows.conn.ConnInfo, buf)
|
||||
err := decoder.DecodeText(rows.connInfo, buf)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
}
|
||||
@@ -235,7 +230,7 @@ func (rows *connRows) Values() ([]interface{}, error) {
|
||||
if decoder == nil {
|
||||
decoder = &pgtype.GenericBinary{}
|
||||
}
|
||||
err := decoder.DecodeBinary(rows.conn.ConnInfo, buf)
|
||||
err := decoder.DecodeBinary(rows.connInfo, buf)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
}
|
||||
@@ -263,3 +258,15 @@ type scanArgError struct {
|
||||
func (e scanArgError) Error() string {
|
||||
return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err)
|
||||
}
|
||||
|
||||
// RowsFromResultReader wraps a *pgconn.ResultReader in a Rows wrapper so a more convenient scanning interface can be
|
||||
// used.
|
||||
//
|
||||
// In most cases, the appropriate pgx query methods should be used instead of sending a query with pgconn and reading
|
||||
// the results with pgx.
|
||||
func RowsFromResultReader(connInfo *pgtype.ConnInfo, rr *pgconn.ResultReader) Rows {
|
||||
return &connRows{
|
||||
connInfo: connInfo,
|
||||
resultReader: rr,
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user