Handle extended protocol with too many arguments
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -720,10 +721,18 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(paramValues) > math.MaxUint16 {
|
||||||
|
result.concludeCommand("", fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16))
|
||||||
|
result.closed = true
|
||||||
|
pgConn.unlock()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
result.concludeCommand("", ctx.Err())
|
result.concludeCommand("", ctx.Err())
|
||||||
result.closed = true
|
result.closed = true
|
||||||
|
pgConn.unlock()
|
||||||
return result
|
return result
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
@@ -776,10 +785,18 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(paramValues) > math.MaxUint16 {
|
||||||
|
result.concludeCommand("", fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16))
|
||||||
|
result.closed = true
|
||||||
|
pgConn.unlock()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
result.concludeCommand("", ctx.Err())
|
result.concludeCommand("", ctx.Err())
|
||||||
result.closed = true
|
result.closed = true
|
||||||
|
pgConn.unlock()
|
||||||
return result
|
return result
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|||||||
+106
@@ -9,9 +9,11 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -379,6 +381,52 @@ func TestConnExecParams(t *testing.T) {
|
|||||||
ensureConnValid(t, pgConn)
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnExecParamsMaxNumberOfParams(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
paramCount := math.MaxUint16
|
||||||
|
params := make([]string, 0, paramCount)
|
||||||
|
args := make([][]byte, 0, paramCount)
|
||||||
|
for i := 0; i < paramCount; i++ {
|
||||||
|
params = append(params, fmt.Sprintf("($%d::text)", i+1))
|
||||||
|
args = append(args, []byte(strconv.Itoa(i)))
|
||||||
|
}
|
||||||
|
sql := "values" + strings.Join(params, ", ")
|
||||||
|
|
||||||
|
result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read()
|
||||||
|
require.NoError(t, result.Err)
|
||||||
|
require.Len(t, result.Rows, paramCount)
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnExecParamsTooManyParams(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
paramCount := math.MaxUint16 + 1
|
||||||
|
params := make([]string, 0, paramCount)
|
||||||
|
args := make([][]byte, 0, paramCount)
|
||||||
|
for i := 0; i < paramCount; i++ {
|
||||||
|
params = append(params, fmt.Sprintf("($%d::text)", i+1))
|
||||||
|
args = append(args, []byte(strconv.Itoa(i)))
|
||||||
|
}
|
||||||
|
sql := "values" + strings.Join(params, ", ")
|
||||||
|
|
||||||
|
result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read()
|
||||||
|
require.Error(t, result.Err)
|
||||||
|
require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error())
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnExecParamsCanceled(t *testing.T) {
|
func TestConnExecParamsCanceled(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -428,6 +476,64 @@ func TestConnExecPrepared(t *testing.T) {
|
|||||||
ensureConnValid(t, pgConn)
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnExecPreparedMaxNumberOfParams(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
paramCount := math.MaxUint16
|
||||||
|
params := make([]string, 0, paramCount)
|
||||||
|
args := make([][]byte, 0, paramCount)
|
||||||
|
for i := 0; i < paramCount; i++ {
|
||||||
|
params = append(params, fmt.Sprintf("($%d::text)", i+1))
|
||||||
|
args = append(args, []byte(strconv.Itoa(i)))
|
||||||
|
}
|
||||||
|
sql := "values" + strings.Join(params, ", ")
|
||||||
|
|
||||||
|
psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, psd)
|
||||||
|
assert.Len(t, psd.ParamOIDs, paramCount)
|
||||||
|
assert.Len(t, psd.Fields, 1)
|
||||||
|
|
||||||
|
result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read()
|
||||||
|
require.NoError(t, result.Err)
|
||||||
|
require.Len(t, result.Rows, paramCount)
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnExecPreparedTooManyParams(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
paramCount := math.MaxUint16 + 1
|
||||||
|
params := make([]string, 0, paramCount)
|
||||||
|
args := make([][]byte, 0, paramCount)
|
||||||
|
for i := 0; i < paramCount; i++ {
|
||||||
|
params = append(params, fmt.Sprintf("($%d::text)", i+1))
|
||||||
|
args = append(args, []byte(strconv.Itoa(i)))
|
||||||
|
}
|
||||||
|
sql := "values" + strings.Join(params, ", ")
|
||||||
|
|
||||||
|
psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, psd)
|
||||||
|
assert.Len(t, psd.ParamOIDs, paramCount)
|
||||||
|
assert.Len(t, psd.Fields, 1)
|
||||||
|
|
||||||
|
result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read()
|
||||||
|
require.Error(t, result.Err)
|
||||||
|
require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error())
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnExecPreparedCanceled(t *testing.T) {
|
func TestConnExecPreparedCanceled(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user