diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 75934ced..3244b504 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -344,10 +344,12 @@ func NewMap() *Map { registerDefaultPgTypeVariants("int8", "_int8", int64(0)) // Integer types that do not have a direct match to a PostgreSQL type + registerDefaultPgTypeVariants("int8", "_int8", int8(0)) + registerDefaultPgTypeVariants("int8", "_int8", int(0)) + registerDefaultPgTypeVariants("int8", "_int8", uint8(0)) registerDefaultPgTypeVariants("int8", "_int8", uint16(0)) registerDefaultPgTypeVariants("int8", "_int8", uint32(0)) registerDefaultPgTypeVariants("int8", "_int8", uint64(0)) - registerDefaultPgTypeVariants("int8", "_int8", int(0)) registerDefaultPgTypeVariants("int8", "_int8", uint(0)) registerDefaultPgTypeVariants("float4", "_float4", float32(0)) @@ -355,12 +357,46 @@ func NewMap() *Map { registerDefaultPgTypeVariants("bool", "_bool", false) registerDefaultPgTypeVariants("timestamptz", "_timestamptz", time.Time{}) + registerDefaultPgTypeVariants("interval", "_interval", time.Duration(0)) registerDefaultPgTypeVariants("text", "_text", "") registerDefaultPgTypeVariants("bytea", "_bytea", []byte(nil)) registerDefaultPgTypeVariants("inet", "_inet", net.IP{}) registerDefaultPgTypeVariants("cidr", "_cidr", net.IPNet{}) + // pgtype provided structs + registerDefaultPgTypeVariants("varbit", "_varbit", Bits{}) + registerDefaultPgTypeVariants("bool", "_bool", Bool{}) + registerDefaultPgTypeVariants("box", "_box", Box{}) + registerDefaultPgTypeVariants("circle", "_circle", Circle{}) + registerDefaultPgTypeVariants("date", "_date", Date{}) + registerDefaultPgTypeVariants("daterange", "_daterange", Daterange{}) + registerDefaultPgTypeVariants("float4", "_float4", Float4{}) + registerDefaultPgTypeVariants("float8", "_float8", Float8{}) + registerDefaultPgTypeVariants("float8range", "_float8range", Float8range{}) + registerDefaultPgTypeVariants("inet", "_inet", Inet{}) + registerDefaultPgTypeVariants("int2", "_int2", Int2{}) + registerDefaultPgTypeVariants("int4", "_int4", Int4{}) + registerDefaultPgTypeVariants("int4range", "_int4range", Int4range{}) + registerDefaultPgTypeVariants("int8", "_int8", Int8{}) + registerDefaultPgTypeVariants("int8range", "_int8range", Int8range{}) + registerDefaultPgTypeVariants("interval", "_interval", Interval{}) + registerDefaultPgTypeVariants("line", "_line", Line{}) + registerDefaultPgTypeVariants("lseg", "_lseg", Lseg{}) + registerDefaultPgTypeVariants("numeric", "_numeric", Numeric{}) + registerDefaultPgTypeVariants("numrange", "_numrange", Numrange{}) + registerDefaultPgTypeVariants("path", "_path", Path{}) + registerDefaultPgTypeVariants("point", "_point", Point{}) + registerDefaultPgTypeVariants("polygon", "_polygon", Polygon{}) + registerDefaultPgTypeVariants("tid", "_tid", TID{}) + registerDefaultPgTypeVariants("text", "_text", Text{}) + registerDefaultPgTypeVariants("time", "_time", Time{}) + registerDefaultPgTypeVariants("timestamp", "_timestamp", Timestamp{}) + registerDefaultPgTypeVariants("timestamptz", "_timestamptz", Timestamptz{}) + registerDefaultPgTypeVariants("tsrange", "_tsrange", Tsrange{}) + registerDefaultPgTypeVariants("tstzrange", "_tstzrange", Tstzrange{}) + registerDefaultPgTypeVariants("uuid", "_uuid", UUID{}) + return m } @@ -1181,13 +1217,13 @@ func (m *Map) PlanEncode(oid uint32, format int16, value interface{}) EncodePlan if plan := dt.Codec.PlanEncode(m, oid, format, value); plan != nil { return plan } + } - for _, f := range m.TryWrapEncodePlanFuncs { - if wrapperPlan, nextValue, ok := f(value); ok { - if nextPlan := m.PlanEncode(oid, format, nextValue); nextPlan != nil { - wrapperPlan.SetNext(nextPlan) - return wrapperPlan - } + for _, f := range m.TryWrapEncodePlanFuncs { + if wrapperPlan, nextValue, ok := f(value); ok { + if nextPlan := m.PlanEncode(oid, format, nextValue); nextPlan != nil { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan } } } diff --git a/values.go b/values.go index 766074bd..0f34b6a6 100644 --- a/values.go +++ b/values.go @@ -30,9 +30,13 @@ func convertSimpleArgument(m *pgtype.Map, arg interface{}) (interface{}, error) return nil, nil } + if dv, ok := arg.(driver.Valuer); ok { + return dv.Value() + } + + // All these could be handled by m.Encode below. However, that transforms the argument to a string. That could change + // the type of the argument. e.g. '42' instead of 42. So standard types are special cased. switch arg := arg.(type) { - case driver.Valuer: - return arg.Value() case float32: return float64(arg), nil case float64: