Handle extended protocol with too many arguments
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -720,10 +721,18 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
|
||||
return result
|
||||
}
|
||||
|
||||
if len(paramValues) > math.MaxUint16 {
|
||||
result.concludeCommand("", fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16))
|
||||
result.closed = true
|
||||
pgConn.unlock()
|
||||
return result
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
result.concludeCommand("", ctx.Err())
|
||||
result.closed = true
|
||||
pgConn.unlock()
|
||||
return result
|
||||
default:
|
||||
}
|
||||
@@ -776,10 +785,18 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
|
||||
return result
|
||||
}
|
||||
|
||||
if len(paramValues) > math.MaxUint16 {
|
||||
result.concludeCommand("", fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16))
|
||||
result.closed = true
|
||||
pgConn.unlock()
|
||||
return result
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
result.concludeCommand("", ctx.Err())
|
||||
result.closed = true
|
||||
pgConn.unlock()
|
||||
return result
|
||||
default:
|
||||
}
|
||||
|
||||
+106
@@ -9,9 +9,11 @@ import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -379,6 +381,52 @@ func TestConnExecParams(t *testing.T) {
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecParamsMaxNumberOfParams(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
paramCount := math.MaxUint16
|
||||
params := make([]string, 0, paramCount)
|
||||
args := make([][]byte, 0, paramCount)
|
||||
for i := 0; i < paramCount; i++ {
|
||||
params = append(params, fmt.Sprintf("($%d::text)", i+1))
|
||||
args = append(args, []byte(strconv.Itoa(i)))
|
||||
}
|
||||
sql := "values" + strings.Join(params, ", ")
|
||||
|
||||
result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read()
|
||||
require.NoError(t, result.Err)
|
||||
require.Len(t, result.Rows, paramCount)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecParamsTooManyParams(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
paramCount := math.MaxUint16 + 1
|
||||
params := make([]string, 0, paramCount)
|
||||
args := make([][]byte, 0, paramCount)
|
||||
for i := 0; i < paramCount; i++ {
|
||||
params = append(params, fmt.Sprintf("($%d::text)", i+1))
|
||||
args = append(args, []byte(strconv.Itoa(i)))
|
||||
}
|
||||
sql := "values" + strings.Join(params, ", ")
|
||||
|
||||
result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read()
|
||||
require.Error(t, result.Err)
|
||||
require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error())
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecParamsCanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -428,6 +476,64 @@ func TestConnExecPrepared(t *testing.T) {
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecPreparedMaxNumberOfParams(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
paramCount := math.MaxUint16
|
||||
params := make([]string, 0, paramCount)
|
||||
args := make([][]byte, 0, paramCount)
|
||||
for i := 0; i < paramCount; i++ {
|
||||
params = append(params, fmt.Sprintf("($%d::text)", i+1))
|
||||
args = append(args, []byte(strconv.Itoa(i)))
|
||||
}
|
||||
sql := "values" + strings.Join(params, ", ")
|
||||
|
||||
psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
assert.Len(t, psd.ParamOIDs, paramCount)
|
||||
assert.Len(t, psd.Fields, 1)
|
||||
|
||||
result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read()
|
||||
require.NoError(t, result.Err)
|
||||
require.Len(t, result.Rows, paramCount)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecPreparedTooManyParams(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
paramCount := math.MaxUint16 + 1
|
||||
params := make([]string, 0, paramCount)
|
||||
args := make([][]byte, 0, paramCount)
|
||||
for i := 0; i < paramCount; i++ {
|
||||
params = append(params, fmt.Sprintf("($%d::text)", i+1))
|
||||
args = append(args, []byte(strconv.Itoa(i)))
|
||||
}
|
||||
sql := "values" + strings.Join(params, ", ")
|
||||
|
||||
psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
assert.Len(t, psd.ParamOIDs, paramCount)
|
||||
assert.Len(t, psd.Fields, 1)
|
||||
|
||||
result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read()
|
||||
require.Error(t, result.Err)
|
||||
require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error())
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecPreparedCanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user