diff --git a/conn.go b/conn.go index 99b4c05f..f549e03e 100644 --- a/conn.go +++ b/conn.go @@ -1458,6 +1458,10 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, return "", err } } else if options != nil && len(options.ParameterOIDs) > 0 { + if err := c.ensureConnectionReadyForQuery(); err != nil { + return "", err + } + buf, err := c.buildOneRoundTripExec(c.wbuf, sql, options, arguments) if err != nil { return "", err diff --git a/conn_test.go b/conn_test.go index d9369a1a..1996f814 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1196,6 +1196,34 @@ func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) { } } +func TestConnExecExIncorrectParameterOIDsAfterAnotherQuery(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, "create temporary table foo(name varchar primary key);") + + var s string + err := conn.QueryRow("insert into foo(name) values('baz') returning name;").Scan(&s) + if err != nil { + t.Errorf("Executing query failed: %v", err) + } + if s != "baz" { + t.Errorf("Query did not return expected value: %v", s) + } + + _, err = conn.ExecEx( + context.Background(), + "insert into foo(name) values($1);", + &pgx.QueryExOptions{ParameterOIDs: []pgtype.OID{pgtype.Int4OID}}, + "bar'; drop table foo;--", + ) + if err == nil { + t.Fatal("expected error but got none") + } +} + func TestPrepare(t *testing.T) { t.Parallel()