From 408837dcb1e5fb4535ab313178a64a6ad79d9bbb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 5 Apr 2019 11:47:31 -0500 Subject: [PATCH] Handle extended protocol with too many arguments --- pgconn.go | 17 ++++++++ pgconn_test.go | 106 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+) diff --git a/pgconn.go b/pgconn.go index e246bcdd..223b8e3d 100644 --- a/pgconn.go +++ b/pgconn.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "math" "net" "strconv" "strings" @@ -720,10 +721,18 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] return result } + if len(paramValues) > math.MaxUint16 { + result.concludeCommand("", fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) + result.closed = true + pgConn.unlock() + return result + } + select { case <-ctx.Done(): result.concludeCommand("", ctx.Err()) result.closed = true + pgConn.unlock() return result default: } @@ -776,10 +785,18 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return result } + if len(paramValues) > math.MaxUint16 { + result.concludeCommand("", fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) + result.closed = true + pgConn.unlock() + return result + } + select { case <-ctx.Done(): result.concludeCommand("", ctx.Err()) result.closed = true + pgConn.unlock() return result default: } diff --git a/pgconn_test.go b/pgconn_test.go index ab8ae173..b2514e48 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -9,9 +9,11 @@ import ( "io" "io/ioutil" "log" + "math" "net" "os" "strconv" + "strings" "testing" "time" @@ -379,6 +381,52 @@ func TestConnExecParams(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecParamsMaxNumberOfParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) + + ensureConnValid(t, pgConn) +} + +func TestConnExecParamsTooManyParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + + ensureConnValid(t, pgConn) +} + func TestConnExecParamsCanceled(t *testing.T) { t.Parallel() @@ -428,6 +476,64 @@ func TestConnExecPrepared(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecPreparedMaxNumberOfParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) + + result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) + + ensureConnValid(t, pgConn) +} + +func TestConnExecPreparedTooManyParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) + + result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + + ensureConnValid(t, pgConn) +} + func TestConnExecPreparedCanceled(t *testing.T) { t.Parallel()