diff --git a/.gitignore b/.gitignore index 348e014f..a2ebbe9c 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,5 @@ _testmain.go .envrc /.testdb + +.DS_Store diff --git a/conn.go b/conn.go index b6c1ab8f..b4bbdb5b 100644 --- a/conn.go +++ b/conn.go @@ -721,43 +721,10 @@ optionLoop: sd, explicitPreparedStatement := c.preparedStatements[sql] if sd != nil || mode == QueryExecModeCacheStatement || mode == QueryExecModeCacheDescribe || mode == QueryExecModeDescribeExec { if sd == nil { - switch mode { - case QueryExecModeCacheStatement: - if c.statementCache == nil { - err = errDisabledStatementCache - rows.fatal(err) - return rows, err - } - sd = c.statementCache.Get(sql) - if sd == nil { - sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql) - if err != nil { - rows.fatal(err) - return rows, err - } - c.statementCache.Put(sd) - } - case QueryExecModeCacheDescribe: - if c.descriptionCache == nil { - err = errDisabledDescriptionCache - rows.fatal(err) - return rows, err - } - sd = c.descriptionCache.Get(sql) - if sd == nil { - sd, err = c.Prepare(ctx, "", sql) - if err != nil { - rows.fatal(err) - return rows, err - } - c.descriptionCache.Put(sd) - } - case QueryExecModeDescribeExec: - sd, err = c.Prepare(ctx, "", sql) - if err != nil { - rows.fatal(err) - return rows, err - } + sd, err = c.getStatementDescription(ctx, mode, sql) + if err != nil { + rows.fatal(err) + return rows, err } } @@ -827,6 +794,48 @@ optionLoop: return rows, rows.err } +// getStatementDescription returns the statement description of the sql query +// according to the given mode. +// +// If the mode is one that doesn't require to know the param and result OIDs +// then nil is returned without error. +func (c *Conn) getStatementDescription( + ctx context.Context, + mode QueryExecMode, + sql string, +) (sd *pgconn.StatementDescription, err error) { + + switch mode { + case QueryExecModeCacheStatement: + if c.statementCache == nil { + return nil, errDisabledStatementCache + } + sd = c.statementCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql) + if err != nil { + return nil, err + } + c.statementCache.Put(sd) + } + case QueryExecModeCacheDescribe: + if c.descriptionCache == nil { + return nil, errDisabledDescriptionCache + } + sd = c.descriptionCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, "", sql) + if err != nil { + return nil, err + } + c.descriptionCache.Put(sd) + } + case QueryExecModeDescribeExec: + return c.Prepare(ctx, "", sql) + } + return sd, err +} + // QueryRow is a convenience wrapper over Query. Any error that occurs while // querying is deferred until calling Scan on the returned Row. That Row will // error with ErrNoRows if no rows are returned. diff --git a/copy_from.go b/copy_from.go index 41acafdc..40232bdf 100644 --- a/copy_from.go +++ b/copy_from.go @@ -85,6 +85,7 @@ type copyFrom struct { columnNames []string rowSrc CopyFromSource readerErrChan chan error + mode QueryExecMode } func (ct *copyFrom) run(ctx context.Context) (int64, error) { @@ -105,9 +106,29 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) { } quotedColumnNames := cbuf.String() - sd, err := ct.conn.Prepare(ctx, "", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName)) - if err != nil { - return 0, err + var sd *pgconn.StatementDescription + switch ct.mode { + case QueryExecModeExec, QueryExecModeSimpleProtocol: + // These modes don't support the binary format. Before the inclusion of the + // QueryExecModes, Conn.Prepare was called on every COPY operation to get + // the OIDs. These prepared statements were not cached. + // + // Since that's the same behavior provided by QueryExecModeDescribeExec, + // we'll default to that mode. + ct.mode = QueryExecModeDescribeExec + fallthrough + case QueryExecModeCacheStatement, QueryExecModeCacheDescribe, QueryExecModeDescribeExec: + var err error + sd, err = ct.conn.getStatementDescription( + ctx, + ct.mode, + fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName), + ) + if err != nil { + return 0, fmt.Errorf("statement description failed: %w", err) + } + default: + return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode) } r, w := io.Pipe() @@ -208,6 +229,7 @@ func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames [ columnNames: columnNames, rowSrc: rowSrc, readerErrChan: make(chan error), + mode: c.config.DefaultQueryExecMode, } return ct.run(ctx) diff --git a/copy_from_test.go b/copy_from_test.go index 49bfcb34..1342c14c 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -14,6 +14,129 @@ import ( "github.com/stretchr/testify/require" ) +func TestConnCopyWithAllQueryExecModes(t *testing.T) { + for _, mode := range pgxtest.AllQueryExecModes { + t.Run(mode.String(), func(t *testing.T) { + t.Parallel() + + cfg := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + cfg.DefaultQueryExecMode = mode + conn := mustConnect(t, cfg) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int2, + b int4, + c int8, + d text, + e timestamptz + )`) + + tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) + + inputRows := [][]any{ + {int16(0), int32(1), int64(2), "abc", tzedTime}, + {nil, nil, nil, nil, nil}, + } + + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e"}, pgx.CopyFromRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyFrom: %v", err) + } + if int(copyCount) != len(inputRows) { + t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query(context.Background(), "select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]any + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + + ensureConnValid(t, conn) + }) + } +} + +func TestConnCopyWithKnownOIDQueryExecModes(t *testing.T) { + + for _, mode := range pgxtest.KnownOIDQueryExecModes { + t.Run(mode.String(), func(t *testing.T) { + t.Parallel() + + cfg := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + cfg.DefaultQueryExecMode = mode + conn := mustConnect(t, cfg) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g timestamptz + )`) + + tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) + + inputRows := [][]any{ + {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, + {nil, nil, nil, nil, nil, nil, nil}, + } + + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyFrom: %v", err) + } + if int(copyCount) != len(inputRows) { + t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query(context.Background(), "select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]any + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + + ensureConnValid(t, conn) + }) + } +} + func TestConnCopyFromSmall(t *testing.T) { t.Parallel() @@ -220,7 +343,7 @@ func TestConnCopyFromEnum(t *testing.T) { conn.TypeMap().RegisterType(typ) } - _, err = tx.Exec(ctx, `create table foo( + _, err = tx.Exec(ctx, `create temporary table foo( a text, b color, c fruit, diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 9b0895d1..3b1ce436 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -18,6 +18,8 @@ import ( "testing" "time" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/internal/pgmock" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" @@ -1666,6 +1668,59 @@ func TestConnCopyFrom(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnCopyFromBinary(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + buf := []byte{} + buf = append(buf, "PGCOPY\n\377\r\n\000"...) + buf = pgio.AppendInt32(buf, 0) + buf = pgio.AppendInt32(buf, 0) + + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + // Number of elements in the tuple + buf = pgio.AppendInt16(buf, int16(2)) + a := i + + // Length of element for column `a int4` + buf = pgio.AppendInt32(buf, 4) + buf, err = pgtype.NewMap().Encode(pgtype.Int4OID, pgx.BinaryFormatCode, a, buf) + require.NoError(t, err) + + b := "foo " + strconv.Itoa(a) + " bar" + lenB := int32(len([]byte(b))) + // Length of element for column `b varchar` + buf = pgio.AppendInt32(buf, lenB) + buf, err = pgtype.NewMap().Encode(pgtype.VarcharOID, pgx.BinaryFormatCode, b, buf) + require.NoError(t, err) + + inputRows = append(inputRows, [][]byte{[]byte(strconv.Itoa(a)), []byte(b)}) + } + + srcBuf := &bytes.Buffer{} + srcBuf.Write(buf) + ct, err := pgConn.CopyFrom(context.Background(), srcBuf, "COPY foo (a, b) FROM STDIN BINARY;") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + + result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + assert.Equal(t, inputRows, result.Rows) + + ensureConnValid(t, pgConn) +} + func TestConnCopyFromCanceled(t *testing.T) { t.Parallel() diff --git a/tracelog/tracelog_test.go b/tracelog/tracelog_test.go index 96b79de9..5458eaea 100644 --- a/tracelog/tracelog_test.go +++ b/tracelog/tracelog_test.go @@ -217,7 +217,7 @@ func TestLogCopyFrom(t *testing.T) { return config } - pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, pgxtest.KnownOIDQueryExecModes, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { _, err := conn.Exec(context.Background(), `create temporary table foo(a int4)`) require.NoError(t, err)