diff --git a/conn.go b/conn.go index e9641508..6b210043 100644 --- a/conn.go +++ b/conn.go @@ -90,54 +90,30 @@ func (c *conn) Close() (err error) { return } -func (c *conn) Query(sql string) (rows []map[string]string, err error) { +func (c *conn) query(sql string, onDataRow func(*messageReader, []fieldDescription) error) (err error) { if err = c.sendSimpleQuery(sql); err != nil { return } + var callbackError error var fields []fieldDescription - rows = make([]map[string]string, 0) + for { var t byte var r *messageReader if t, r, err = c.rxMsg(); err == nil { switch t { case readyForQuery: - return rows, nil + if err == nil { + err = callbackError + } + return case rowDescription: fields = c.rxRowDescription(r) case dataRow: - rows = append(rows, c.rxDataRow(r, fields)) - case commandComplete: - c.rxCommandComplete(r) - default: - if err = c.processContextFreeMsg(t, r); err != nil { - return nil, err + if callbackError == nil { + callbackError = onDataRow(r, fields) } - } - } else { - return nil, err - } - } - - panic("Unreachable") -} - -func (c *conn) selectOne(sql string) (s string, err error) { - if err = c.sendSimpleQuery(sql); err != nil { - return - } - - for { - var t byte - var r *messageReader - if t, r, err = c.rxMsg(); err == nil { - switch t { - case readyForQuery: - return - case rowDescription: - case dataRow: - s = c.rxDataRowFirstValue(r) case commandComplete: default: if err = c.processContextFreeMsg(t, r); err != nil { @@ -152,13 +128,28 @@ func (c *conn) selectOne(sql string) (s string, err error) { panic("Unreachable") } +func (c *conn) Query(sql string) (rows []map[string]string, err error) { + rows = make([]map[string]string, 0, 8) + onDataRow := func(r *messageReader, fields []fieldDescription) error { + rows = append(rows, c.rxDataRow(r, fields)) + return nil + } + err = c.query(sql, onDataRow) + return +} + func (c *conn) SelectString(sql string) (s string, err error) { - return c.selectOne(sql) + onDataRow := func(r *messageReader, _ []fieldDescription) error { + s = c.rxDataRowFirstValue(r) + return nil + } + err = c.query(sql, onDataRow) + return } func (c *conn) selectInt(sql string, size int) (i int64, err error) { var s string - s, err = c.selectOne(sql) + s, err = c.SelectString(sql) if err != nil { return } @@ -187,7 +178,7 @@ func (c *conn) SelectInt16(sql string) (i int16, err error) { func (c *conn) selectFloat(sql string, size int) (f float64, err error) { var s string - s, err = c.selectOne(sql) + s, err = c.SelectString(sql) if err != nil { return } @@ -208,107 +199,72 @@ func (c *conn) SelectFloat32(sql string) (f float32, err error) { } func (c *conn) SelectAllString(sql string) (strings []string, err error) { - if err = c.sendSimpleQuery(sql); err != nil { - return + strings = make([]string, 0, 8) + onDataRow := func(r *messageReader, _ []fieldDescription) error { + strings = append(strings, c.rxDataRowFirstValue(r)) + return nil } - - strings = make([]string, 0) - - for { - var t byte - var r *messageReader - if t, r, err = c.rxMsg(); err == nil { - switch t { - case readyForQuery: - return - case rowDescription: - case dataRow: - strings = append(strings, c.rxDataRowFirstValue(r)) - case commandComplete: - default: - if err = c.processContextFreeMsg(t, r); err != nil { - return - } - } - } else { - return - } - } - - panic("Unreachable") -} - -func (c *conn) selectAllInt(sql string, size int) (ints []int64, err error) { - var strings []string - strings, err = c.SelectAllString(sql) - if err != nil { - return - } - - ints = make([]int64, len(strings)) - for i, s := range strings { - ints[i], err = strconv.ParseInt(s, 10, size) - if err != nil { - return - } - } - + err = c.query(sql, onDataRow) return } func (c *conn) SelectAllInt64(sql string) (ints []int64, err error) { - return c.selectAllInt(sql, 64) + ints = make([]int64, 0, 8) + onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { + var i int64 + i, parseError = strconv.ParseInt(c.rxDataRowFirstValue(r), 10, 64) + ints = append(ints, i) + return + } + err = c.query(sql, onDataRow) + return } func (c *conn) SelectAllInt32(sql string) (ints []int32, err error) { - var int64s []int64 - int64s, err = c.selectAllInt(sql, 32) - ints = make([]int32, len(int64s)) - for i := 0; i < len(int64s); i++ { - ints[i] = int32(int64s[i]) + ints = make([]int32, 0, 8) + onDataRow := func(r *messageReader, fields []fieldDescription) (parseError error) { + var i int64 + i, parseError = strconv.ParseInt(c.rxDataRowFirstValue(r), 10, 32) + ints = append(ints, int32(i)) + return } + err = c.query(sql, onDataRow) return } func (c *conn) SelectAllInt16(sql string) (ints []int16, err error) { - var int64s []int64 - int64s, err = c.selectAllInt(sql, 16) - ints = make([]int16, len(int64s)) - for i := 0; i < len(int64s); i++ { - ints[i] = int16(int64s[i]) - } - return -} - -func (c *conn) selectAllFloat(sql string, size int) (floats []float64, err error) { - var strings []string - strings, err = c.SelectAllString(sql) - if err != nil { + ints = make([]int16, 0, 8) + onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { + var i int64 + i, parseError = strconv.ParseInt(c.rxDataRowFirstValue(r), 10, 16) + ints = append(ints, int16(i)) return } - - floats = make([]float64, len(strings)) - for i, s := range strings { - floats[i], err = strconv.ParseFloat(s, size) - if err != nil { - return - } - } - + err = c.query(sql, onDataRow) return } func (c *conn) SelectAllFloat64(sql string) (floats []float64, err error) { - return c.selectAllFloat(sql, 64) + floats = make([]float64, 0, 8) + onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { + var f float64 + f, parseError = strconv.ParseFloat(c.rxDataRowFirstValue(r), 64) + floats = append(floats, f) + return + } + err = c.query(sql, onDataRow) + return } func (c *conn) SelectAllFloat32(sql string) (floats []float32, err error) { - var float64s []float64 - float64s, err = c.selectAllFloat(sql, 32) - floats = make([]float32, len(float64s)) - for i := 0; i < len(float64s); i++ { - floats[i] = float32(float64s[i]) + floats = make([]float32, 0, 8) + onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { + var f float64 + f, parseError = strconv.ParseFloat(c.rxDataRowFirstValue(r), 32) + floats = append(floats, float32(f)) + return } + err = c.query(sql, onDataRow) return }