Add RowReader.CopyBytes
Implement SelectValueTo in terms of RowReader.CopyBytes
This commit is contained in:
@@ -333,67 +333,26 @@ func (c *Conn) SelectValueTo(w io.Writer, sql string, arguments ...interface{})
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = c.sendQuery(sql, arguments...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var numRowsFound int64
|
var numRowsFound int64
|
||||||
var softErr error
|
|
||||||
|
|
||||||
for {
|
qr, _ := c.Query(sql, arguments...)
|
||||||
var t byte
|
|
||||||
var r *MsgReader
|
|
||||||
|
|
||||||
t, r, err = c.rxMsg()
|
for qr.NextRow() {
|
||||||
if err != nil {
|
if len(qr.fields) != 1 {
|
||||||
return err
|
qr.Close()
|
||||||
|
return UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: int16(len(qr.fields))}
|
||||||
}
|
}
|
||||||
|
|
||||||
if t == dataRow {
|
numRowsFound++
|
||||||
numRowsFound++
|
if numRowsFound != 1 {
|
||||||
|
qr.Close()
|
||||||
if numRowsFound > 1 {
|
return NotSingleRowError{RowCount: numRowsFound}
|
||||||
softErr = NotSingleRowError{RowCount: numRowsFound}
|
|
||||||
}
|
|
||||||
|
|
||||||
if softErr != nil {
|
|
||||||
// Read and discard rest of message
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
softErr = c.rxDataRowValueTo(w, r)
|
|
||||||
} else {
|
|
||||||
switch t {
|
|
||||||
case readyForQuery:
|
|
||||||
c.rxReadyForQuery(r)
|
|
||||||
return softErr
|
|
||||||
case rowDescription:
|
|
||||||
case commandComplete:
|
|
||||||
case bindComplete:
|
|
||||||
default:
|
|
||||||
if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil {
|
|
||||||
softErr = e
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var rr RowReader
|
||||||
|
rr.CopyBytes(qr, w)
|
||||||
}
|
}
|
||||||
}
|
return qr.Err()
|
||||||
|
|
||||||
func (c *Conn) rxDataRowValueTo(w io.Writer, r *MsgReader) error {
|
|
||||||
columnCount := r.ReadInt16()
|
|
||||||
if columnCount != 1 {
|
|
||||||
return UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: columnCount}
|
|
||||||
}
|
|
||||||
|
|
||||||
valueSize := r.ReadInt32()
|
|
||||||
if valueSize == -1 {
|
|
||||||
return errors.New("SelectValueTo cannot handle null")
|
|
||||||
}
|
|
||||||
|
|
||||||
r.CopyN(w, valueSize)
|
|
||||||
|
|
||||||
return r.Err()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare creates a prepared statement with name and sql. sql can contain placeholders
|
// Prepare creates a prepared statement with name and sql. sql can contain placeholders
|
||||||
@@ -554,9 +513,9 @@ func (rr *RowReader) ReadInt32(qr *QueryResult) int32 {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO - do something about nulls
|
|
||||||
if size == -1 {
|
if size == -1 {
|
||||||
panic("Can't handle nulls")
|
qr.Fatal(errors.New("Unexpected null"))
|
||||||
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
return decodeInt4(qr, fd, size)
|
return decodeInt4(qr, fd, size)
|
||||||
@@ -568,9 +527,9 @@ func (rr *RowReader) ReadInt64(qr *QueryResult) int64 {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO - do something about nulls
|
|
||||||
if size == -1 {
|
if size == -1 {
|
||||||
panic("Can't handle nulls")
|
qr.Fatal(errors.New("Unexpected null"))
|
||||||
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
return decodeInt8(qr, fd, size)
|
return decodeInt8(qr, fd, size)
|
||||||
@@ -584,9 +543,9 @@ func (rr *RowReader) ReadTime(qr *QueryResult) time.Time {
|
|||||||
return zeroTime
|
return zeroTime
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO - do something about nulls
|
|
||||||
if size == -1 {
|
if size == -1 {
|
||||||
panic("Can't handle nulls")
|
qr.Fatal(errors.New("Unexpected null"))
|
||||||
|
return zeroTime
|
||||||
}
|
}
|
||||||
|
|
||||||
return decodeTimestampTz(qr, fd, size)
|
return decodeTimestampTz(qr, fd, size)
|
||||||
@@ -600,9 +559,9 @@ func (rr *RowReader) ReadDate(qr *QueryResult) time.Time {
|
|||||||
return zeroTime
|
return zeroTime
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO - do something about nulls
|
|
||||||
if size == -1 {
|
if size == -1 {
|
||||||
panic("Can't handle nulls")
|
qr.Fatal(errors.New("Unexpected null"))
|
||||||
|
return zeroTime
|
||||||
}
|
}
|
||||||
|
|
||||||
return decodeDate(qr, fd, size)
|
return decodeDate(qr, fd, size)
|
||||||
@@ -614,6 +573,11 @@ func (rr *RowReader) ReadString(qr *QueryResult) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if size == -1 {
|
||||||
|
qr.Fatal(errors.New("Unexpected null"))
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
return decodeText(qr, fd, size)
|
return decodeText(qr, fd, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -634,6 +598,20 @@ func (rr *RowReader) ReadValue(qr *QueryResult) interface{} {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rr *RowReader) CopyBytes(qr *QueryResult, w io.Writer) {
|
||||||
|
_, size, ok := qr.NextColumn()
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if size == -1 {
|
||||||
|
qr.Fatal(errors.New("Unexpected null"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
qr.MsgReader().CopyN(w, size)
|
||||||
|
}
|
||||||
|
|
||||||
type QueryResult struct {
|
type QueryResult struct {
|
||||||
pool *ConnPool
|
pool *ConnPool
|
||||||
conn *Conn
|
conn *Conn
|
||||||
|
|||||||
+39
-6
@@ -462,6 +462,42 @@ func TestConnQueryReadTooManyValues(t *testing.T) {
|
|||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueryResultCopyBytes(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
var mimeType string
|
||||||
|
var buf bytes.Buffer
|
||||||
|
|
||||||
|
qr, err := conn.Query("select 'application/json', '[1,2,3,4,5]'::json")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("conn.Query failed: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for qr.NextRow() {
|
||||||
|
var rr pgx.RowReader
|
||||||
|
mimeType = rr.ReadString(qr)
|
||||||
|
rr.CopyBytes(qr, &buf)
|
||||||
|
}
|
||||||
|
qr.Close()
|
||||||
|
|
||||||
|
if qr.Err() != nil {
|
||||||
|
t.Fatalf("conn.Query failed: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if mimeType != "application/json" {
|
||||||
|
t.Errorf(`Expected mimeType to be "application/json", but it was "%v"`, mimeType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Compare(buf.Bytes(), []byte("[1,2,3,4,5]")) != 0 {
|
||||||
|
t.Fatalf("CopyBytes did not write expected data: %v", string(buf.Bytes()))
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnectionSelectValue(t *testing.T) {
|
func TestConnectionSelectValue(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -546,14 +582,11 @@ func TestConnectionSelectValueTo(t *testing.T) {
|
|||||||
|
|
||||||
// Null
|
// Null
|
||||||
err = conn.SelectValueTo(&buf, "select null")
|
err = conn.SelectValueTo(&buf, "select null")
|
||||||
if err == nil || err.Error() != "SelectValueTo cannot handle null" {
|
if err == nil || err.Error() != "Unexpected null" {
|
||||||
t.Fatalf("Expected null error: %#v", err)
|
t.Fatalf("Expected null error: %#v", err)
|
||||||
}
|
}
|
||||||
if conn.IsAlive() {
|
|
||||||
mustSelectValue(t, conn, "select 1") // ensure it really is alive and usable
|
ensureConnValid(t, conn)
|
||||||
} else {
|
|
||||||
t.Fatal("SelectValueTo null error should not have killed connection")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPrepare(t *testing.T) {
|
func TestPrepare(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user