2
0

Add RowReader.CopyBytes

Implement SelectValueTo in terms of RowReader.CopyBytes
This commit is contained in:
Jack Christensen
2014-07-05 07:50:46 -05:00
parent a1fc6f513a
commit b27d828311
2 changed files with 79 additions and 68 deletions
+40 -62
View File
@@ -333,67 +333,26 @@ func (c *Conn) SelectValueTo(w io.Writer, sql string, arguments ...interface{})
} }
}() }()
err = c.sendQuery(sql, arguments...)
if err != nil {
return err
}
var numRowsFound int64 var numRowsFound int64
var softErr error
for { qr, _ := c.Query(sql, arguments...)
var t byte
var r *MsgReader
t, r, err = c.rxMsg() for qr.NextRow() {
if err != nil { if len(qr.fields) != 1 {
return err qr.Close()
return UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: int16(len(qr.fields))}
} }
if t == dataRow { numRowsFound++
numRowsFound++ if numRowsFound != 1 {
qr.Close()
if numRowsFound > 1 { return NotSingleRowError{RowCount: numRowsFound}
softErr = NotSingleRowError{RowCount: numRowsFound}
}
if softErr != nil {
// Read and discard rest of message
continue
}
softErr = c.rxDataRowValueTo(w, r)
} else {
switch t {
case readyForQuery:
c.rxReadyForQuery(r)
return softErr
case rowDescription:
case commandComplete:
case bindComplete:
default:
if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil {
softErr = e
}
}
} }
var rr RowReader
rr.CopyBytes(qr, w)
} }
} return qr.Err()
func (c *Conn) rxDataRowValueTo(w io.Writer, r *MsgReader) error {
columnCount := r.ReadInt16()
if columnCount != 1 {
return UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: columnCount}
}
valueSize := r.ReadInt32()
if valueSize == -1 {
return errors.New("SelectValueTo cannot handle null")
}
r.CopyN(w, valueSize)
return r.Err()
} }
// Prepare creates a prepared statement with name and sql. sql can contain placeholders // Prepare creates a prepared statement with name and sql. sql can contain placeholders
@@ -554,9 +513,9 @@ func (rr *RowReader) ReadInt32(qr *QueryResult) int32 {
return 0 return 0
} }
// TODO - do something about nulls
if size == -1 { if size == -1 {
panic("Can't handle nulls") qr.Fatal(errors.New("Unexpected null"))
return 0
} }
return decodeInt4(qr, fd, size) return decodeInt4(qr, fd, size)
@@ -568,9 +527,9 @@ func (rr *RowReader) ReadInt64(qr *QueryResult) int64 {
return 0 return 0
} }
// TODO - do something about nulls
if size == -1 { if size == -1 {
panic("Can't handle nulls") qr.Fatal(errors.New("Unexpected null"))
return 0
} }
return decodeInt8(qr, fd, size) return decodeInt8(qr, fd, size)
@@ -584,9 +543,9 @@ func (rr *RowReader) ReadTime(qr *QueryResult) time.Time {
return zeroTime return zeroTime
} }
// TODO - do something about nulls
if size == -1 { if size == -1 {
panic("Can't handle nulls") qr.Fatal(errors.New("Unexpected null"))
return zeroTime
} }
return decodeTimestampTz(qr, fd, size) return decodeTimestampTz(qr, fd, size)
@@ -600,9 +559,9 @@ func (rr *RowReader) ReadDate(qr *QueryResult) time.Time {
return zeroTime return zeroTime
} }
// TODO - do something about nulls
if size == -1 { if size == -1 {
panic("Can't handle nulls") qr.Fatal(errors.New("Unexpected null"))
return zeroTime
} }
return decodeDate(qr, fd, size) return decodeDate(qr, fd, size)
@@ -614,6 +573,11 @@ func (rr *RowReader) ReadString(qr *QueryResult) string {
return "" return ""
} }
if size == -1 {
qr.Fatal(errors.New("Unexpected null"))
return ""
}
return decodeText(qr, fd, size) return decodeText(qr, fd, size)
} }
@@ -634,6 +598,20 @@ func (rr *RowReader) ReadValue(qr *QueryResult) interface{} {
} }
} }
func (rr *RowReader) CopyBytes(qr *QueryResult, w io.Writer) {
_, size, ok := qr.NextColumn()
if !ok {
return
}
if size == -1 {
qr.Fatal(errors.New("Unexpected null"))
return
}
qr.MsgReader().CopyN(w, size)
}
type QueryResult struct { type QueryResult struct {
pool *ConnPool pool *ConnPool
conn *Conn conn *Conn
+39 -6
View File
@@ -462,6 +462,42 @@ func TestConnQueryReadTooManyValues(t *testing.T) {
ensureConnValid(t, conn) ensureConnValid(t, conn)
} }
func TestQueryResultCopyBytes(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
var mimeType string
var buf bytes.Buffer
qr, err := conn.Query("select 'application/json', '[1,2,3,4,5]'::json")
if err != nil {
t.Fatalf("conn.Query failed: ", err)
}
for qr.NextRow() {
var rr pgx.RowReader
mimeType = rr.ReadString(qr)
rr.CopyBytes(qr, &buf)
}
qr.Close()
if qr.Err() != nil {
t.Fatalf("conn.Query failed: ", err)
}
if mimeType != "application/json" {
t.Errorf(`Expected mimeType to be "application/json", but it was "%v"`, mimeType)
}
if bytes.Compare(buf.Bytes(), []byte("[1,2,3,4,5]")) != 0 {
t.Fatalf("CopyBytes did not write expected data: %v", string(buf.Bytes()))
}
ensureConnValid(t, conn)
}
func TestConnectionSelectValue(t *testing.T) { func TestConnectionSelectValue(t *testing.T) {
t.Parallel() t.Parallel()
@@ -546,14 +582,11 @@ func TestConnectionSelectValueTo(t *testing.T) {
// Null // Null
err = conn.SelectValueTo(&buf, "select null") err = conn.SelectValueTo(&buf, "select null")
if err == nil || err.Error() != "SelectValueTo cannot handle null" { if err == nil || err.Error() != "Unexpected null" {
t.Fatalf("Expected null error: %#v", err) t.Fatalf("Expected null error: %#v", err)
} }
if conn.IsAlive() {
mustSelectValue(t, conn, "select 1") // ensure it really is alive and usable ensureConnValid(t, conn)
} else {
t.Fatal("SelectValueTo null error should not have killed connection")
}
} }
func TestPrepare(t *testing.T) { func TestPrepare(t *testing.T) {