Build / rewrite / port multirange support
This commit is contained in:
@@ -0,0 +1,443 @@
|
||||
package pgtype
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
// MultirangeGetter is a type that can be converted into a PostgreSQL multirange.
|
||||
type MultirangeGetter interface {
|
||||
// IsNull returns true if the value is SQL NULL.
|
||||
IsNull() bool
|
||||
|
||||
// Len returns the number of elements in the multirange.
|
||||
Len() int
|
||||
|
||||
// Index returns the element at i.
|
||||
Index(i int) any
|
||||
|
||||
// IndexType returns a non-nil scan target of the type Index will return. This is used by MultirangeCodec.PlanEncode.
|
||||
IndexType() any
|
||||
}
|
||||
|
||||
// MultirangeSetter is a type can be set from a PostgreSQL multirange.
|
||||
type MultirangeSetter interface {
|
||||
// ScanNull sets the value to SQL NULL.
|
||||
ScanNull() error
|
||||
|
||||
// SetLen prepares the value such that ScanIndex can be called for each element. This will remove any existing
|
||||
// elements.
|
||||
SetLen(n int) error
|
||||
|
||||
// ScanIndex returns a value usable as a scan target for i. SetLen must be called before ScanIndex.
|
||||
ScanIndex(i int) any
|
||||
|
||||
// ScanIndexType returns a non-nil scan target of the type ScanIndex will return. This is used by
|
||||
// MultirangeCodec.PlanScan.
|
||||
ScanIndexType() any
|
||||
}
|
||||
|
||||
// MultirangeCodec is a codec for any multirange type.
|
||||
type MultirangeCodec struct {
|
||||
ElementType *Type
|
||||
}
|
||||
|
||||
func (c *MultirangeCodec) FormatSupported(format int16) bool {
|
||||
return c.ElementType.Codec.FormatSupported(format)
|
||||
}
|
||||
|
||||
func (c *MultirangeCodec) PreferredFormat() int16 {
|
||||
return c.ElementType.Codec.PreferredFormat()
|
||||
}
|
||||
|
||||
func (c *MultirangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
|
||||
multirangeValuer, ok := value.(MultirangeGetter)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
elementType := multirangeValuer.IndexType()
|
||||
|
||||
elementEncodePlan := m.PlanEncode(c.ElementType.OID, format, elementType)
|
||||
if elementEncodePlan == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch format {
|
||||
case BinaryFormatCode:
|
||||
return &encodePlanMultirangeCodecBinary{ac: c, m: m, oid: oid}
|
||||
case TextFormatCode:
|
||||
return &encodePlanMultirangeCodecText{ac: c, m: m, oid: oid}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type encodePlanMultirangeCodecText struct {
|
||||
ac *MultirangeCodec
|
||||
m *Map
|
||||
oid uint32
|
||||
}
|
||||
|
||||
func (p *encodePlanMultirangeCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||
multirange := value.(MultirangeGetter)
|
||||
|
||||
if multirange.IsNull() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
elementCount := multirange.Len()
|
||||
|
||||
buf = append(buf, '{')
|
||||
|
||||
var encodePlan EncodePlan
|
||||
var lastElemType reflect.Type
|
||||
inElemBuf := make([]byte, 0, 32)
|
||||
for i := 0; i < elementCount; i++ {
|
||||
if i > 0 {
|
||||
buf = append(buf, ',')
|
||||
}
|
||||
|
||||
elem := multirange.Index(i)
|
||||
var elemBuf []byte
|
||||
if elem != nil {
|
||||
elemType := reflect.TypeOf(elem)
|
||||
if lastElemType != elemType {
|
||||
lastElemType = elemType
|
||||
encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, TextFormatCode, elem)
|
||||
if encodePlan == nil {
|
||||
return nil, fmt.Errorf("unable to encode %v", multirange.Index(i))
|
||||
}
|
||||
}
|
||||
elemBuf, err = encodePlan.Encode(elem, inElemBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if elemBuf == nil {
|
||||
return nil, fmt.Errorf("multirange cannot contain NULL element")
|
||||
} else {
|
||||
buf = append(buf, elemBuf...)
|
||||
}
|
||||
}
|
||||
|
||||
buf = append(buf, '}')
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
type encodePlanMultirangeCodecBinary struct {
|
||||
ac *MultirangeCodec
|
||||
m *Map
|
||||
oid uint32
|
||||
}
|
||||
|
||||
func (p *encodePlanMultirangeCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||
multirange := value.(MultirangeGetter)
|
||||
|
||||
if multirange.IsNull() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
elementCount := multirange.Len()
|
||||
|
||||
buf = pgio.AppendInt32(buf, int32(elementCount))
|
||||
|
||||
var encodePlan EncodePlan
|
||||
var lastElemType reflect.Type
|
||||
for i := 0; i < elementCount; i++ {
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
|
||||
elem := multirange.Index(i)
|
||||
var elemBuf []byte
|
||||
if elem != nil {
|
||||
elemType := reflect.TypeOf(elem)
|
||||
if lastElemType != elemType {
|
||||
lastElemType = elemType
|
||||
encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, BinaryFormatCode, elem)
|
||||
if encodePlan == nil {
|
||||
return nil, fmt.Errorf("unable to encode %v", multirange.Index(i))
|
||||
}
|
||||
}
|
||||
elemBuf, err = encodePlan.Encode(elem, buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if elemBuf == nil {
|
||||
return nil, fmt.Errorf("multirange cannot contain NULL element")
|
||||
} else {
|
||||
buf = elemBuf
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
|
||||
}
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func (c *MultirangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||
multirangeScanner, ok := target.(MultirangeSetter)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
elementType := multirangeScanner.ScanIndexType()
|
||||
|
||||
elementScanPlan := m.PlanScan(c.ElementType.OID, format, elementType)
|
||||
if _, ok := elementScanPlan.(*scanPlanFail); ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &scanPlanMultirangeCodec{
|
||||
multirangeCodec: c,
|
||||
m: m,
|
||||
oid: oid,
|
||||
formatCode: format,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MultirangeCodec) decodeBinary(m *Map, multirangeOID uint32, src []byte, multirange MultirangeSetter) error {
|
||||
rp := 0
|
||||
|
||||
elementCount := int(binary.BigEndian.Uint32(src[rp:]))
|
||||
rp += 4
|
||||
|
||||
err := multirange.SetLen(elementCount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if elementCount == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0))
|
||||
if elementScanPlan == nil {
|
||||
elementScanPlan = m.PlanScan(c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0))
|
||||
}
|
||||
|
||||
for i := 0; i < elementCount; i++ {
|
||||
elem := multirange.ScanIndex(i)
|
||||
elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||
rp += 4
|
||||
var elemSrc []byte
|
||||
if elemLen >= 0 {
|
||||
elemSrc = src[rp : rp+elemLen]
|
||||
rp += elemLen
|
||||
}
|
||||
err = elementScanPlan.Scan(elemSrc, elem)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to scan multirange element %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MultirangeCodec) decodeText(m *Map, multirangeOID uint32, src []byte, multirange MultirangeSetter) error {
|
||||
elements, err := parseUntypedTextMultirange(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = multirange.SetLen(len(elements))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(elements) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, TextFormatCode, multirange.ScanIndex(0))
|
||||
if elementScanPlan == nil {
|
||||
elementScanPlan = m.PlanScan(c.ElementType.OID, TextFormatCode, multirange.ScanIndex(0))
|
||||
}
|
||||
|
||||
for i, s := range elements {
|
||||
elem := multirange.ScanIndex(i)
|
||||
err = elementScanPlan.Scan([]byte(s), elem)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type scanPlanMultirangeCodec struct {
|
||||
multirangeCodec *MultirangeCodec
|
||||
m *Map
|
||||
oid uint32
|
||||
formatCode int16
|
||||
elementScanPlan ScanPlan
|
||||
}
|
||||
|
||||
func (spac *scanPlanMultirangeCodec) Scan(src []byte, dst any) error {
|
||||
c := spac.multirangeCodec
|
||||
m := spac.m
|
||||
oid := spac.oid
|
||||
formatCode := spac.formatCode
|
||||
|
||||
multirange := dst.(MultirangeSetter)
|
||||
|
||||
if src == nil {
|
||||
return multirange.ScanNull()
|
||||
}
|
||||
|
||||
switch formatCode {
|
||||
case BinaryFormatCode:
|
||||
return c.decodeBinary(m, oid, src, multirange)
|
||||
case TextFormatCode:
|
||||
return c.decodeText(m, oid, src, multirange)
|
||||
default:
|
||||
return fmt.Errorf("unknown format code %d", formatCode)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MultirangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
|
||||
if src == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch format {
|
||||
case TextFormatCode:
|
||||
return string(src), nil
|
||||
case BinaryFormatCode:
|
||||
buf := make([]byte, len(src))
|
||||
copy(buf, src)
|
||||
return buf, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown format code %d", format)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MultirangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
||||
if src == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var multirange Multirange[Range[any]]
|
||||
err := m.PlanScan(oid, format, &multirange).Scan(src, &multirange)
|
||||
return multirange, err
|
||||
}
|
||||
|
||||
func parseUntypedTextMultirange(src []byte) ([]string, error) {
|
||||
elements := make([]string, 0)
|
||||
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
skipWhitespace(buf)
|
||||
|
||||
r, _, err := buf.ReadRune()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid array: %v", err)
|
||||
}
|
||||
|
||||
if r != '{' {
|
||||
return nil, fmt.Errorf("invalid multirange, expected '{': %v", err)
|
||||
}
|
||||
|
||||
parseValueLoop:
|
||||
for {
|
||||
r, _, err = buf.ReadRune()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid multirange: %v", err)
|
||||
}
|
||||
|
||||
switch r {
|
||||
case ',': // skip range separator
|
||||
case '}':
|
||||
break parseValueLoop
|
||||
default:
|
||||
buf.UnreadRune()
|
||||
value, err := parseRange(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid multirange value: %v", err)
|
||||
}
|
||||
elements = append(elements, value)
|
||||
}
|
||||
}
|
||||
|
||||
skipWhitespace(buf)
|
||||
|
||||
if buf.Len() > 0 {
|
||||
return nil, fmt.Errorf("unexpected trailing data: %v", buf.String())
|
||||
}
|
||||
|
||||
return elements, nil
|
||||
|
||||
}
|
||||
|
||||
func parseRange(buf *bytes.Buffer) (string, error) {
|
||||
s := &bytes.Buffer{}
|
||||
|
||||
boundSepRead := false
|
||||
for {
|
||||
r, _, err := buf.ReadRune()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
switch r {
|
||||
case ',', '}':
|
||||
if r == ',' && !boundSepRead {
|
||||
boundSepRead = true
|
||||
break
|
||||
}
|
||||
buf.UnreadRune()
|
||||
return s.String(), nil
|
||||
}
|
||||
|
||||
s.WriteRune(r)
|
||||
}
|
||||
}
|
||||
|
||||
// Multirange is a generic multirange type.
|
||||
//
|
||||
// T should implement RangeValuer and *T should implement RangeScanner. However, there does not appear to be a way to
|
||||
// enforce the RangeScanner constraint.
|
||||
type Multirange[T RangeValuer] []T
|
||||
|
||||
func (r Multirange[T]) IsNull() bool {
|
||||
return r == nil
|
||||
}
|
||||
|
||||
func (r Multirange[T]) Len() int {
|
||||
return len(r)
|
||||
}
|
||||
|
||||
func (r Multirange[T]) Index(i int) any {
|
||||
return r[i]
|
||||
}
|
||||
|
||||
func (r Multirange[T]) IndexType() any {
|
||||
var zero T
|
||||
return zero
|
||||
}
|
||||
|
||||
func (r *Multirange[T]) ScanNull() error {
|
||||
*r = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Multirange[T]) SetLen(n int) error {
|
||||
*r = make([]T, n)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r Multirange[T]) ScanIndex(i int) any {
|
||||
return &r[i]
|
||||
}
|
||||
|
||||
func (r Multirange[T]) ScanIndexType() any {
|
||||
return new(T)
|
||||
}
|
||||
Reference in New Issue
Block a user