From 9b08b8939602e1a55972e5543d64a017feb13daf Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Sat, 26 Nov 2022 11:20:55 +1100 Subject: [PATCH] fix: ensure pointers can be detected as bools This required adding a BoolMapperExt interface. --- mapper.go | 64 +++++++++++++++++++++++++++++++++++++++---------------- model.go | 3 +++ 2 files changed, 49 insertions(+), 18 deletions(-) diff --git a/mapper.go b/mapper.go index 65f9814..06cb71f 100644 --- a/mapper.go +++ b/mapper.go @@ -123,9 +123,16 @@ type VarsContributor interface { // // This is used solely for formatting help. type BoolMapper interface { + Mapper IsBool() bool } +// BoolMapperExt allows a Mapper to dynamically determine if a value is a boolean. +type BoolMapperExt interface { + Mapper + IsBoolFromValue(v reflect.Value) bool +} + // A MapperFunc is a single function that complies with the Mapper interface. type MapperFunc func(ctx *DecodeContext, target reflect.Value) error @@ -186,6 +193,7 @@ func (r *Registry) ForType(typ reflect.Type) Mapper { // Check if the type implements MapperValue. for _, impl := range []reflect.Type{typ, reflect.PtrTo(typ)} { if impl.Implements(mapperValueType) { + // FIXME: This should pass in the bool mapper. return &mapperValueAdapter{impl.Implements(boolMapperType)} } } @@ -223,8 +231,8 @@ func (r *Registry) RegisterKind(kind reflect.Kind, mapper Mapper) *Registry { // // eg. // -// Mapper string `kong:"type='colour'` -// registry.RegisterName("colour", ...) +// Mapper string `kong:"type='colour'` +// registry.RegisterName("colour", ...) func (r *Registry) RegisterName(name string, mapper Mapper) *Registry { r.names[name] = mapper return r @@ -275,7 +283,7 @@ func (r *Registry) RegisterDefaults() *Registry { RegisterName("existingfile", existingFileMapper(r)). RegisterName("existingdir", existingDirMapper(r)). RegisterName("counter", counterMapper()). - RegisterKind(reflect.Ptr, ptrMapper(r)) + RegisterKind(reflect.Ptr, ptrMapper{r}) } type boolMapper struct{} @@ -676,20 +684,40 @@ func existingDirMapper(r *Registry) MapperFunc { } } -func ptrMapper(r *Registry) MapperFunc { - return func(ctx *DecodeContext, target reflect.Value) error { - elem := reflect.New(target.Type().Elem()).Elem() - nestedMapper := r.ForValue(elem) - if nestedMapper == nil { - return fmt.Errorf("cannot find mapper for %v", target.Type().Elem().String()) - } - err := nestedMapper.Decode(ctx, elem) - if err != nil { - return err - } - target.Set(elem.Addr()) - return nil +type ptrMapper struct { + r *Registry +} + +var _ BoolMapperExt = (*ptrMapper)(nil) + +// IsBoolFromValue implements BoolMapperExt +func (p ptrMapper) IsBoolFromValue(target reflect.Value) bool { + elem := reflect.New(target.Type().Elem()).Elem() + nestedMapper := p.r.ForValue(elem) + if nestedMapper == nil { + return false } + if bm, ok := nestedMapper.(BoolMapper); ok && bm.IsBool() { + return true + } + if bm, ok := nestedMapper.(BoolMapperExt); ok && bm.IsBoolFromValue(target) { + return true + } + return target.Kind() == reflect.Ptr && target.Type().Elem().Kind() == reflect.Bool +} + +func (p ptrMapper) Decode(ctx *DecodeContext, target reflect.Value) error { + elem := reflect.New(target.Type().Elem()).Elem() + nestedMapper := p.r.ForValue(elem) + if nestedMapper == nil { + return fmt.Errorf("cannot find mapper for %v", target.Type().Elem().String()) + } + err := nestedMapper.Decode(ctx, elem) + if err != nil { + return err + } + target.Set(elem.Addr()) + return nil } func counterMapper() MapperFunc { @@ -753,7 +781,7 @@ func urlMapper() MapperFunc { // // It differs from strings.Split() in that the separator can exist in a field by escaping it with a \. eg. // -// SplitEscaped(`hello\,there,bob`, ',') == []string{"hello,there", "bob"} +// SplitEscaped(`hello\,there,bob`, ',') == []string{"hello,there", "bob"} func SplitEscaped(s string, sep rune) (out []string) { if sep == -1 { return []string{s} @@ -786,7 +814,7 @@ func SplitEscaped(s string, sep rune) (out []string) { // JoinEscaped joins a slice of strings on sep, but also escapes any instances of sep in the fields with \. eg. // -// JoinEscaped([]string{"hello,there", "bob"}, ',') == `hello\,there,bob` +// JoinEscaped([]string{"hello,there", "bob"}, ',') == `hello\,there,bob` func JoinEscaped(s []string, sep rune) string { escaped := []string{} for _, e := range s { diff --git a/model.go b/model.go index 5879a4b..1428965 100644 --- a/model.go +++ b/model.go @@ -318,6 +318,9 @@ func (v *Value) IsMap() bool { // IsBool returns true if the underlying value is a boolean. func (v *Value) IsBool() bool { + if m, ok := v.Mapper.(BoolMapperExt); ok && m.IsBoolFromValue(v.Target) { + return true + } if m, ok := v.Mapper.(BoolMapper); ok && m.IsBool() { return true }