41fa6e844c
Logging previously only logged statement results on Exec() and Query(), but not CopyFrom(). This makes change makes it consistent.
213 lines
5.1 KiB
Go
213 lines
5.1 KiB
Go
package pgx
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"time"
|
|
|
|
"github.com/jackc/pgconn"
|
|
"github.com/jackc/pgio"
|
|
errors "golang.org/x/xerrors"
|
|
)
|
|
|
|
// CopyFromRows returns a CopyFromSource interface over the provided rows slice
|
|
// making it usable by *Conn.CopyFrom.
|
|
func CopyFromRows(rows [][]interface{}) CopyFromSource {
|
|
return ©FromRows{rows: rows, idx: -1}
|
|
}
|
|
|
|
type copyFromRows struct {
|
|
rows [][]interface{}
|
|
idx int
|
|
}
|
|
|
|
func (ctr *copyFromRows) Next() bool {
|
|
ctr.idx++
|
|
return ctr.idx < len(ctr.rows)
|
|
}
|
|
|
|
func (ctr *copyFromRows) Values() ([]interface{}, error) {
|
|
return ctr.rows[ctr.idx], nil
|
|
}
|
|
|
|
func (ctr *copyFromRows) Err() error {
|
|
return nil
|
|
}
|
|
|
|
// CopyFromSlice returns a CopyFromSource interface over a dynamic func
|
|
// making it usable by *Conn.CopyFrom.
|
|
func CopyFromSlice(length int, next func(int) ([]interface{}, error)) CopyFromSource {
|
|
return ©FromSlice{next: next, idx: -1, len: length}
|
|
}
|
|
|
|
type copyFromSlice struct {
|
|
next func(int) ([]interface{}, error)
|
|
idx int
|
|
len int
|
|
err error
|
|
}
|
|
|
|
func (cts *copyFromSlice) Next() bool {
|
|
cts.idx++
|
|
return cts.idx < cts.len
|
|
}
|
|
|
|
func (cts *copyFromSlice) Values() ([]interface{}, error) {
|
|
values, err := cts.next(cts.idx)
|
|
if err != nil {
|
|
cts.err = err
|
|
}
|
|
return values, err
|
|
}
|
|
|
|
func (cts *copyFromSlice) Err() error {
|
|
return cts.err
|
|
}
|
|
|
|
// CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data.
|
|
type CopyFromSource interface {
|
|
// Next returns true if there is another row and makes the next row data
|
|
// available to Values(). When there are no more rows available or an error
|
|
// has occurred it returns false.
|
|
Next() bool
|
|
|
|
// Values returns the values for the current row.
|
|
Values() ([]interface{}, error)
|
|
|
|
// Err returns any error that has been encountered by the CopyFromSource. If
|
|
// this is not nil *Conn.CopyFrom will abort the copy.
|
|
Err() error
|
|
}
|
|
|
|
type copyFrom struct {
|
|
conn *Conn
|
|
tableName Identifier
|
|
columnNames []string
|
|
rowSrc CopyFromSource
|
|
readerErrChan chan error
|
|
}
|
|
|
|
func (ct *copyFrom) run(ctx context.Context) (int64, error) {
|
|
quotedTableName := ct.tableName.Sanitize()
|
|
cbuf := &bytes.Buffer{}
|
|
for i, cn := range ct.columnNames {
|
|
if i != 0 {
|
|
cbuf.WriteString(", ")
|
|
}
|
|
cbuf.WriteString(quoteIdentifier(cn))
|
|
}
|
|
quotedColumnNames := cbuf.String()
|
|
|
|
sd, err := ct.conn.Prepare(ctx, "", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
r, w := io.Pipe()
|
|
doneChan := make(chan struct{})
|
|
|
|
go func() {
|
|
defer close(doneChan)
|
|
|
|
// Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283.
|
|
buf := ct.conn.wbuf
|
|
|
|
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
|
|
buf = pgio.AppendInt32(buf, 0)
|
|
buf = pgio.AppendInt32(buf, 0)
|
|
|
|
moreRows := true
|
|
for moreRows {
|
|
var err error
|
|
moreRows, buf, err = ct.buildCopyBuf(buf, sd)
|
|
if err != nil {
|
|
w.CloseWithError(err)
|
|
return
|
|
}
|
|
|
|
if ct.rowSrc.Err() != nil {
|
|
w.CloseWithError(ct.rowSrc.Err())
|
|
return
|
|
}
|
|
|
|
if len(buf) > 0 {
|
|
_, err = w.Write(buf)
|
|
if err != nil {
|
|
w.Close()
|
|
return
|
|
}
|
|
}
|
|
|
|
buf = buf[:0]
|
|
}
|
|
|
|
w.Close()
|
|
}()
|
|
|
|
startTime := time.Now()
|
|
|
|
commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
|
|
|
|
r.Close()
|
|
<-doneChan
|
|
|
|
rowsAffected := commandTag.RowsAffected()
|
|
if err == nil {
|
|
if ct.conn.shouldLog(LogLevelInfo) {
|
|
endTime := time.Now()
|
|
ct.conn.log(ctx, LogLevelInfo, "Copy", map[string]interface{}{"tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime), "rowCount": rowsAffected})
|
|
}
|
|
} else if ct.conn.shouldLog(LogLevelError) {
|
|
ct.conn.log(ctx, LogLevelError, "Copy", map[string]interface{}{"err": err, "tableName": ct.tableName, "columnNames": ct.columnNames})
|
|
}
|
|
|
|
return rowsAffected, err
|
|
}
|
|
|
|
func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) {
|
|
|
|
for ct.rowSrc.Next() {
|
|
values, err := ct.rowSrc.Values()
|
|
if err != nil {
|
|
return false, nil, err
|
|
}
|
|
if len(values) != len(ct.columnNames) {
|
|
return false, nil, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
|
|
}
|
|
|
|
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
|
|
for i, val := range values {
|
|
buf, err = encodePreparedStatementArgument(ct.conn.connInfo, buf, sd.Fields[i].DataTypeOID, val)
|
|
if err != nil {
|
|
return false, nil, err
|
|
}
|
|
}
|
|
|
|
if len(buf) > 65536 {
|
|
return true, buf, nil
|
|
}
|
|
}
|
|
|
|
return false, buf, nil
|
|
}
|
|
|
|
// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion.
|
|
// It returns the number of rows copied and an error.
|
|
//
|
|
// CopyFrom requires all values use the binary format. Almost all types
|
|
// implemented by pgx use the binary format by default. Types implementing
|
|
// Encoder can only be used if they encode to the binary format.
|
|
func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
|
|
ct := ©From{
|
|
conn: c,
|
|
tableName: tableName,
|
|
columnNames: columnNames,
|
|
rowSrc: rowSrc,
|
|
readerErrChan: make(chan error),
|
|
}
|
|
|
|
return ct.run(ctx)
|
|
}
|