diff --git a/batch.go b/batch.go index 98f216dd..6fd61295 100644 --- a/batch.go +++ b/batch.go @@ -116,12 +116,12 @@ func (br *batchResults) Query() (Rows, error) { } if br.err != nil { - return &connRows{err: br.err, closed: true}, br.err + return &baseRows{err: br.err, closed: true}, br.err } if br.closed { alreadyClosedErr := fmt.Errorf("batch already closed") - return &connRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr + return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr } rows := br.conn.getRows(br.ctx, query, arguments) @@ -182,7 +182,7 @@ func (br *batchResults) QueryFunc(scans []any, f func(QueryFuncRow) error) (pgco // QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow. func (br *batchResults) QueryRow() Row { rows, _ := br.Query() - return (*connRow)(rows.(*connRows)) + return (*connRow)(rows.(*baseRows)) } diff --git a/conn.go b/conn.go index ec029ace..ba2ba578 100644 --- a/conn.go +++ b/conn.go @@ -75,7 +75,7 @@ type Conn struct { typeMap *pgtype.Map wbuf []byte - eqb extendedQueryBuilder + eqb ExtendedQueryBuilder } // Identifier a PostgreSQL identifier or name. Identifiers can be composed of @@ -485,49 +485,25 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []a return commandTag, err } -func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, args []any) error { - if len(sd.ParamOIDs) != len(args) { - return fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args)) - } - - c.eqb.Reset() - - anynil.NormalizeSlice(args) - - for i := range args { - err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i]) - if err != nil { - err = fmt.Errorf("failed to encode args[%d]: %v", i, err) - return err - } - } - - for i := range sd.Fields { - c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID)) - } - - return nil -} - func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, arguments []any) (pgconn.CommandTag, error) { - err := c.execParamsAndPreparedPrefix(sd, arguments) + err := c.eqb.Build(c.typeMap, sd, arguments) if err != nil { return pgconn.CommandTag{}, err } - result := c.pgConn.ExecParams(ctx, sd.SQL, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read() - c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + result := c.pgConn.ExecParams(ctx, sd.SQL, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats).Read() + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. return result.CommandTag, result.Err } func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []any) (pgconn.CommandTag, error) { - err := c.execParamsAndPreparedPrefix(sd, arguments) + err := c.eqb.Build(c.typeMap, sd, arguments) if err != nil { return pgconn.CommandTag{}, err } - result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read() - c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats).Read() + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. return result.CommandTag, result.Err } @@ -540,79 +516,18 @@ func (e *unknownArgumentTypeQueryExecModeExecError) Error() string { } func (c *Conn) execSQLParams(ctx context.Context, sql string, args []any) (pgconn.CommandTag, error) { - c.eqb.Reset() - - anynil.NormalizeSlice(args) - err := c.appendParamsForQueryExecModeExec(args) + err := c.eqb.Build(c.typeMap, nil, args) if err != nil { return pgconn.CommandTag{}, err } - result := c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, nil, c.eqb.paramFormats, c.eqb.resultFormats).Read() - c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + result := c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats).Read() + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. return result.CommandTag, result.Err } -// appendParamsForQueryExecModeExec appends the args to c.eqb. -// -// Parameters must be encoded in the text format because of differences in type conversion between timestamps and -// dates. In QueryExecModeExec we don't know what the actual PostgreSQL type is. To determine the type we use the -// Go type to OID type mapping registered by RegisterDefaultPgType. However, the Go time.Time represents both -// PostgreSQL timestamp[tz] and date. To use the binary format we would need to also specify what the PostgreSQL -// type OID is. But that would mean telling PostgreSQL that we have sent a timestamp[tz] when what is needed is a date. -// This means that the value is converted from text to timestamp[tz] to date. This means it does a time zone conversion -// before converting it to date. This means that dates can be shifted by one day. In text format without that double -// type conversion it takes the date directly and ignores time zone (i.e. it works). -// -// Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is -// no way to safely use binary or to specify the parameter OIDs. -func (c *Conn) appendParamsForQueryExecModeExec(args []any) error { - for _, arg := range args { - if arg == nil { - err := c.eqb.AppendParamFormat(c.typeMap, 0, TextFormatCode, arg) - if err != nil { - return err - } - } else { - dt, ok := c.TypeMap().TypeForValue(arg) - if !ok { - var tv pgtype.TextValuer - if tv, ok = arg.(pgtype.TextValuer); ok { - t, err := tv.TextValue() - if err != nil { - return err - } - - dt, ok = c.TypeMap().TypeForOID(pgtype.TextOID) - if ok { - arg = t - } - } - } - if !ok { - var str fmt.Stringer - if str, ok = arg.(fmt.Stringer); ok { - dt, ok = c.TypeMap().TypeForOID(pgtype.TextOID) - if ok { - arg = str.String() - } - } - } - if !ok { - return &unknownArgumentTypeQueryExecModeExecError{arg: arg} - } - err := c.eqb.AppendParamFormat(c.typeMap, dt.OID, TextFormatCode, arg) - if err != nil { - return err - } - } - } - - return nil -} - -func (c *Conn) getRows(ctx context.Context, sql string, args []any) *connRows { - r := &connRows{} +func (c *Conn) getRows(ctx context.Context, sql string, args []any) *baseRows { + r := &baseRows{} r.ctx = ctx r.logger = c @@ -735,7 +650,7 @@ optionLoop: sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args) } - c.eqb.Reset() + c.eqb.reset() anynil.NormalizeSlice(args) rows := c.getRows(ctx, sql, args) @@ -782,13 +697,10 @@ optionLoop: rows.sql = sd.SQL - for i := range args { - err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i]) - if err != nil { - err = fmt.Errorf("failed to encode args[%d]: %v", i, err) - rows.fatal(err) - return rows, rows.err - } + err = c.eqb.Build(c.typeMap, sd, args) + if err != nil { + rows.fatal(err) + return rows, rows.err } if resultFormatsByOID != nil { @@ -799,26 +711,22 @@ optionLoop: } if resultFormats == nil { - for i := range sd.Fields { - c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID)) - } - - resultFormats = c.eqb.resultFormats + resultFormats = c.eqb.ResultFormats } if !explicitPreparedStatement && mode == QueryExecModeCacheDescribe { - rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, resultFormats) + rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, resultFormats) } else { - rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats) + rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, resultFormats) } } else if mode == QueryExecModeExec { - err := c.appendParamsForQueryExecModeExec(args) + err := c.eqb.Build(c.typeMap, nil, args) if err != nil { rows.fatal(err) return rows, rows.err } - rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, nil, c.eqb.paramFormats, c.eqb.resultFormats) + rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats) } else if mode == QueryExecModeSimpleProtocol { sql, err = c.sanitizeForSimpleQuery(sql, args...) if err != nil { @@ -843,7 +751,7 @@ optionLoop: return rows, rows.err } - c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. return rows, rows.err } @@ -853,7 +761,7 @@ optionLoop: // error with ErrNoRows if no rows are returned. func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row { rows, _ := c.Query(ctx, sql, args...) - return (*connRow)(rows.(*connRows)) + return (*connRow)(rows.(*baseRows)) } // QueryFuncRow is the argument to the QueryFunc callback function. @@ -954,34 +862,23 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { if mode == QueryExecModeExec { for _, bi := range b.items { - c.eqb.Reset() + c.eqb.reset() anynil.NormalizeSlice(bi.arguments) sd := c.preparedStatements[bi.query] if sd != nil { - if len(sd.ParamOIDs) != len(bi.arguments) { - return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")} - } - - for i := range bi.arguments { - err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i]) - if err != nil { - err = fmt.Errorf("failed to encode args[%d]: %v", i, err) - return &batchResults{ctx: ctx, conn: c, err: err} - } - } - - for i := range sd.Fields { - c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID)) - } - - batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats) - } else { - err := c.appendParamsForQueryExecModeExec(bi.arguments) + err := c.eqb.Build(c.typeMap, sd, bi.arguments) if err != nil { return &batchResults{ctx: ctx, conn: c, err: err} } - batch.ExecParams(bi.query, c.eqb.paramValues, nil, c.eqb.paramFormats, c.eqb.resultFormats) + + batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) + } else { + err := c.eqb.Build(c.typeMap, nil, bi.arguments) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats) } } } else { @@ -1014,7 +911,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { } for _, bi := range b.items { - c.eqb.Reset() + c.eqb.reset() sd := c.preparedStatements[bi.query] if sd == nil { @@ -1029,29 +926,20 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")} } - anynil.NormalizeSlice(bi.arguments) - - for i := range bi.arguments { - err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i]) - if err != nil { - err = fmt.Errorf("failed to encode args[%d]: %v", i, err) - return &batchResults{ctx: ctx, conn: c, err: err} - } - } - - for i := range sd.Fields { - c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID)) + err := c.eqb.Build(c.typeMap, sd, bi.arguments) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} } if sd.Name == "" { - batch.ExecParams(bi.query, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats) + batch.ExecParams(bi.query, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats) } else { - batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats) + batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) } } } - c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. mrr := c.pgConn.ExecBatch(ctx, batch) diff --git a/extended_query_builder.go b/extended_query_builder.go index e69d0b36..1c47063c 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -1,62 +1,98 @@ package pgx import ( + "fmt" + "github.com/jackc/pgx/v5/internal/anynil" + "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" ) -type extendedQueryBuilder struct { - paramValues [][]byte +// ExtendedQueryBuilder is used to choose the parameter formats, to format the parameters and to choose the result +// formats for an extended query. +type ExtendedQueryBuilder struct { + ParamValues [][]byte paramValueBytes []byte - paramFormats []int16 - resultFormats []int16 + ParamFormats []int16 + ResultFormats []int16 } -func (eqb *extendedQueryBuilder) AppendParam(m *pgtype.Map, oid uint32, arg any) error { - f := eqb.chooseParameterFormatCode(m, oid, arg) - return eqb.AppendParamFormat(m, oid, f, arg) +// Build sets ParamValues, ParamFormats, and ResultFormats for use with *PgConn.ExecParams or *PgConn.ExecPrepared. If +// sd is nil then QueryExecModeExec behavior will be used. +func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error { + eqb.reset() + + anynil.NormalizeSlice(args) + + if sd == nil { + return eqb.appendParamsForQueryExecModeExec(m, args) + } + + if len(sd.ParamOIDs) != len(args) { + return fmt.Errorf("mismatched param and argument count") + } + + for i := range args { + err := eqb.appendParam(m, sd.ParamOIDs[i], -1, args[i]) + if err != nil { + err = fmt.Errorf("failed to encode args[%d]: %v", i, err) + return err + } + } + + for i := range sd.Fields { + eqb.appendResultFormat(m.FormatCodeForOID(sd.Fields[i].DataTypeOID)) + } + + return nil } -func (eqb *extendedQueryBuilder) AppendParamFormat(m *pgtype.Map, oid uint32, format int16, arg any) error { - eqb.paramFormats = append(eqb.paramFormats, format) +// appendParam appends a parameter to the query. format may be -1 to automatically choose the format. If arg is nil it +// must be an untyped nil. +func (eqb *ExtendedQueryBuilder) appendParam(m *pgtype.Map, oid uint32, format int16, arg any) error { + if format == -1 { + format = eqb.chooseParameterFormatCode(m, oid, arg) + } + eqb.ParamFormats = append(eqb.ParamFormats, format) v, err := eqb.encodeExtendedParamValue(m, oid, format, arg) if err != nil { return err } - eqb.paramValues = append(eqb.paramValues, v) + eqb.ParamValues = append(eqb.ParamValues, v) return nil } -func (eqb *extendedQueryBuilder) AppendResultFormat(f int16) { - eqb.resultFormats = append(eqb.resultFormats, f) +// appendResultFormat appends a result format to the query. +func (eqb *ExtendedQueryBuilder) appendResultFormat(format int16) { + eqb.ResultFormats = append(eqb.ResultFormats, format) } -// Reset readies eqb to build another query. -func (eqb *extendedQueryBuilder) Reset() { - eqb.paramValues = eqb.paramValues[0:0] +// reset readies eqb to build another query. +func (eqb *ExtendedQueryBuilder) reset() { + eqb.ParamValues = eqb.ParamValues[0:0] eqb.paramValueBytes = eqb.paramValueBytes[0:0] - eqb.paramFormats = eqb.paramFormats[0:0] - eqb.resultFormats = eqb.resultFormats[0:0] + eqb.ParamFormats = eqb.ParamFormats[0:0] + eqb.ResultFormats = eqb.ResultFormats[0:0] - if cap(eqb.paramValues) > 64 { - eqb.paramValues = make([][]byte, 0, 64) + if cap(eqb.ParamValues) > 64 { + eqb.ParamValues = make([][]byte, 0, 64) } if cap(eqb.paramValueBytes) > 256 { eqb.paramValueBytes = make([]byte, 0, 256) } - if cap(eqb.paramFormats) > 64 { - eqb.paramFormats = make([]int16, 0, 64) + if cap(eqb.ParamFormats) > 64 { + eqb.ParamFormats = make([]int16, 0, 64) } - if cap(eqb.resultFormats) > 64 { - eqb.resultFormats = make([]int16, 0, 64) + if cap(eqb.ResultFormats) > 64 { + eqb.ResultFormats = make([]int16, 0, 64) } } -func (eqb *extendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) { +func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) { if anynil.Is(arg) { return nil, nil } @@ -81,7 +117,7 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uin // chooseParameterFormatCode determines the correct format code for an // argument to a prepared statement. It defaults to TextFormatCode if no // determination can be made. -func (eqb *extendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid uint32, arg any) int16 { +func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid uint32, arg any) int16 { switch arg.(type) { case string, *string: return TextFormatCode @@ -89,3 +125,61 @@ func (eqb *extendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid ui return m.FormatCodeForOID(oid) } + +// appendParamsForQueryExecModeExec appends the args to eqb. +// +// Parameters must be encoded in the text format because of differences in type conversion between timestamps and +// dates. In QueryExecModeExec we don't know what the actual PostgreSQL type is. To determine the type we use the +// Go type to OID type mapping registered by RegisterDefaultPgType. However, the Go time.Time represents both +// PostgreSQL timestamp[tz] and date. To use the binary format we would need to also specify what the PostgreSQL +// type OID is. But that would mean telling PostgreSQL that we have sent a timestamp[tz] when what is needed is a date. +// This means that the value is converted from text to timestamp[tz] to date. This means it does a time zone conversion +// before converting it to date. This means that dates can be shifted by one day. In text format without that double +// type conversion it takes the date directly and ignores time zone (i.e. it works). +// +// Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is +// no way to safely use binary or to specify the parameter OIDs. +func (eqb *ExtendedQueryBuilder) appendParamsForQueryExecModeExec(m *pgtype.Map, args []any) error { + for _, arg := range args { + if arg == nil { + err := eqb.appendParam(m, 0, TextFormatCode, arg) + if err != nil { + return err + } + } else { + dt, ok := m.TypeForValue(arg) + if !ok { + var tv pgtype.TextValuer + if tv, ok = arg.(pgtype.TextValuer); ok { + t, err := tv.TextValue() + if err != nil { + return err + } + + dt, ok = m.TypeForOID(pgtype.TextOID) + if ok { + arg = t + } + } + } + if !ok { + var str fmt.Stringer + if str, ok = arg.(fmt.Stringer); ok { + dt, ok = m.TypeForOID(pgtype.TextOID) + if ok { + arg = str.String() + } + } + } + if !ok { + return &unknownArgumentTypeQueryExecModeExecError{arg: arg} + } + err := eqb.appendParam(m, dt.OID, TextFormatCode, arg) + if err != nil { + return err + } + } + } + + return nil +} diff --git a/pipeline_test.go b/pipeline_test.go new file mode 100644 index 00000000..b8590bf9 --- /dev/null +++ b/pipeline_test.go @@ -0,0 +1,79 @@ +package pgx_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/require" +) + +func TestPipelineWithoutPreparedOrDescribedStatements(t *testing.T) { + t.Parallel() + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pipeline := conn.PgConn().StartPipeline(ctx) + + eqb := pgx.ExtendedQueryBuilder{} + + err := eqb.Build(conn.TypeMap(), nil, []any{1, 2}) + require.NoError(t, err) + pipeline.SendQueryParams(`select $1::bigint + $2::bigint`, eqb.ParamValues, nil, eqb.ParamFormats, eqb.ResultFormats) + + err = eqb.Build(conn.TypeMap(), nil, []any{3, 4, 5}) + require.NoError(t, err) + pipeline.SendQueryParams(`select $1::bigint + $2::bigint + $3::bigint`, eqb.ParamValues, nil, eqb.ParamFormats, eqb.ResultFormats) + + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.True(t, ok) + rows := pgx.RowsFromResultReader(conn.TypeMap(), rr) + + rowCount := 0 + var n int64 + for rows.Next() { + err = rows.Scan(&n) + require.NoError(t, err) + rowCount++ + } + require.NoError(t, rows.Err()) + require.Equal(t, 1, rowCount) + require.Equal(t, "SELECT 1", rows.CommandTag().String()) + require.EqualValues(t, 3, n) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.True(t, ok) + rows = pgx.RowsFromResultReader(conn.TypeMap(), rr) + + rowCount = 0 + n = 0 + for rows.Next() { + err = rows.Scan(&n) + require.NoError(t, err) + rowCount++ + } + require.NoError(t, rows.Err()) + require.Equal(t, 1, rowCount) + require.Equal(t, "SELECT 1", rows.CommandTag().String()) + require.EqualValues(t, 12, n) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.True(t, ok) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + }) +} diff --git a/rows.go b/rows.go index 4f9c533d..d9c0ba47 100644 --- a/rows.go +++ b/rows.go @@ -76,10 +76,10 @@ type RowScanner interface { } // connRow implements the Row interface for Conn.QueryRow. -type connRow connRows +type connRow baseRows func (r *connRow) Scan(dest ...any) (err error) { - rows := (*connRows)(r) + rows := (*baseRows)(r) if rows.Err() != nil { return rows.Err() @@ -109,33 +109,36 @@ type rowLog interface { log(ctx context.Context, lvl LogLevel, msg string, data map[string]any) } -// connRows implements the Rows interface for Conn.Query. -type connRows struct { - ctx context.Context - logger rowLog - typeMap *pgtype.Map - values [][]byte - rowCount int - err error - commandTag pgconn.CommandTag - startTime time.Time - sql string - args []any - closed bool - conn *Conn +// baseRows implements the Rows interface for Conn.Query. +type baseRows struct { + typeMap *pgtype.Map + resultReader *pgconn.ResultReader - resultReader *pgconn.ResultReader - multiResultReader *pgconn.MultiResultReader + values [][]byte + + commandTag pgconn.CommandTag + err error + closed bool scanPlans []pgtype.ScanPlan scanTypes []reflect.Type + + conn *Conn + multiResultReader *pgconn.MultiResultReader + + logger rowLog + ctx context.Context + startTime time.Time + sql string + args []any + rowCount int } -func (rows *connRows) FieldDescriptions() []pgproto3.FieldDescription { +func (rows *baseRows) FieldDescriptions() []pgproto3.FieldDescription { return rows.resultReader.FieldDescriptions() } -func (rows *connRows) Close() { +func (rows *baseRows) Close() { if rows.closed { return } @@ -167,24 +170,25 @@ func (rows *connRows) Close() { if rows.logger.shouldLog(LogLevelError) { rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]any{"err": rows.err, "sql": rows.sql, "args": logQueryArgs(rows.args)}) } - if rows.err != nil && rows.conn.statementCache != nil { - rows.conn.statementCache.StatementErrored(rows.sql, rows.err) - } } } + + if rows.err != nil && rows.conn != nil && rows.conn.statementCache != nil { + rows.conn.statementCache.StatementErrored(rows.sql, rows.err) + } } -func (rows *connRows) CommandTag() pgconn.CommandTag { +func (rows *baseRows) CommandTag() pgconn.CommandTag { return rows.commandTag } -func (rows *connRows) Err() error { +func (rows *baseRows) Err() error { return rows.err } // fatal signals an error occurred after the query was sent to the server. It // closes the rows automatically. -func (rows *connRows) fatal(err error) { +func (rows *baseRows) fatal(err error) { if rows.err != nil { return } @@ -193,7 +197,7 @@ func (rows *connRows) fatal(err error) { rows.Close() } -func (rows *connRows) Next() bool { +func (rows *baseRows) Next() bool { if rows.closed { return false } @@ -208,7 +212,7 @@ func (rows *connRows) Next() bool { } } -func (rows *connRows) Scan(dest ...any) error { +func (rows *baseRows) Scan(dest ...any) error { m := rows.typeMap fieldDescriptions := rows.FieldDescriptions() values := rows.values @@ -261,7 +265,7 @@ func (rows *connRows) Scan(dest ...any) error { return nil } -func (rows *connRows) Values() ([]any, error) { +func (rows *baseRows) Values() ([]any, error) { if rows.closed { return nil, errors.New("rows is closed") } @@ -304,7 +308,7 @@ func (rows *connRows) Values() ([]any, error) { return values, rows.Err() } -func (rows *connRows) RawValues() [][]byte { +func (rows *baseRows) RawValues() [][]byte { return rows.values } @@ -348,3 +352,12 @@ func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgproto3.FieldDescription, return nil } + +// RowsFromResultReader returns a Rows that will read from values resultReader and decode with typeMap. It can be used +// to read from the lower level pgconn interface. +func RowsFromResultReader(typeMap *pgtype.Map, resultReader *pgconn.ResultReader) Rows { + return &baseRows{ + typeMap: typeMap, + resultReader: resultReader, + } +} diff --git a/tx.go b/tx.go index 7254e3dc..76b1768c 100644 --- a/tx.go +++ b/tx.go @@ -281,7 +281,7 @@ func (tx *dbTx) Query(ctx context.Context, sql string, args ...any) (Rows, error if tx.closed { // Because checking for errors can be deferred to the *Rows, build one with the error err := ErrTxClosed - return &connRows{closed: true, err: err}, err + return &baseRows{closed: true, err: err}, err } return tx.conn.Query(ctx, sql, args...) @@ -290,7 +290,7 @@ func (tx *dbTx) Query(ctx context.Context, sql string, args ...any) (Rows, error // QueryRow delegates to the underlying *Conn func (tx *dbTx) QueryRow(ctx context.Context, sql string, args ...any) Row { rows, _ := tx.Query(ctx, sql, args...) - return (*connRow)(rows.(*connRows)) + return (*connRow)(rows.(*baseRows)) } // QueryFunc delegates to the underlying *Conn. @@ -400,7 +400,7 @@ func (sp *dbSimulatedNestedTx) Query(ctx context.Context, sql string, args ...an if sp.closed { // Because checking for errors can be deferred to the *Rows, build one with the error err := ErrTxClosed - return &connRows{closed: true, err: err}, err + return &baseRows{closed: true, err: err}, err } return sp.tx.Query(ctx, sql, args...) @@ -409,7 +409,7 @@ func (sp *dbSimulatedNestedTx) Query(ctx context.Context, sql string, args ...an // QueryRow delegates to the underlying Tx func (sp *dbSimulatedNestedTx) QueryRow(ctx context.Context, sql string, args ...any) Row { rows, _ := sp.Query(ctx, sql, args...) - return (*connRow)(rows.(*connRows)) + return (*connRow)(rows.(*baseRows)) } // QueryFunc delegates to the underlying Tx.