fix: ensure pointers can be detected as bools

This required adding a BoolMapperExt interface.
This commit is contained in:
Alec Thomas
2022-11-26 11:20:55 +11:00
parent bf0cbf5d7c
commit 9b08b89396
2 changed files with 49 additions and 18 deletions
+46 -18
View File
@@ -123,9 +123,16 @@ type VarsContributor interface {
//
// This is used solely for formatting help.
type BoolMapper interface {
Mapper
IsBool() bool
}
// BoolMapperExt allows a Mapper to dynamically determine if a value is a boolean.
type BoolMapperExt interface {
Mapper
IsBoolFromValue(v reflect.Value) bool
}
// A MapperFunc is a single function that complies with the Mapper interface.
type MapperFunc func(ctx *DecodeContext, target reflect.Value) error
@@ -186,6 +193,7 @@ func (r *Registry) ForType(typ reflect.Type) Mapper {
// Check if the type implements MapperValue.
for _, impl := range []reflect.Type{typ, reflect.PtrTo(typ)} {
if impl.Implements(mapperValueType) {
// FIXME: This should pass in the bool mapper.
return &mapperValueAdapter{impl.Implements(boolMapperType)}
}
}
@@ -223,8 +231,8 @@ func (r *Registry) RegisterKind(kind reflect.Kind, mapper Mapper) *Registry {
//
// eg.
//
// Mapper string `kong:"type='colour'`
// registry.RegisterName("colour", ...)
// Mapper string `kong:"type='colour'`
// registry.RegisterName("colour", ...)
func (r *Registry) RegisterName(name string, mapper Mapper) *Registry {
r.names[name] = mapper
return r
@@ -275,7 +283,7 @@ func (r *Registry) RegisterDefaults() *Registry {
RegisterName("existingfile", existingFileMapper(r)).
RegisterName("existingdir", existingDirMapper(r)).
RegisterName("counter", counterMapper()).
RegisterKind(reflect.Ptr, ptrMapper(r))
RegisterKind(reflect.Ptr, ptrMapper{r})
}
type boolMapper struct{}
@@ -676,20 +684,40 @@ func existingDirMapper(r *Registry) MapperFunc {
}
}
func ptrMapper(r *Registry) MapperFunc {
return func(ctx *DecodeContext, target reflect.Value) error {
elem := reflect.New(target.Type().Elem()).Elem()
nestedMapper := r.ForValue(elem)
if nestedMapper == nil {
return fmt.Errorf("cannot find mapper for %v", target.Type().Elem().String())
}
err := nestedMapper.Decode(ctx, elem)
if err != nil {
return err
}
target.Set(elem.Addr())
return nil
type ptrMapper struct {
r *Registry
}
var _ BoolMapperExt = (*ptrMapper)(nil)
// IsBoolFromValue implements BoolMapperExt
func (p ptrMapper) IsBoolFromValue(target reflect.Value) bool {
elem := reflect.New(target.Type().Elem()).Elem()
nestedMapper := p.r.ForValue(elem)
if nestedMapper == nil {
return false
}
if bm, ok := nestedMapper.(BoolMapper); ok && bm.IsBool() {
return true
}
if bm, ok := nestedMapper.(BoolMapperExt); ok && bm.IsBoolFromValue(target) {
return true
}
return target.Kind() == reflect.Ptr && target.Type().Elem().Kind() == reflect.Bool
}
func (p ptrMapper) Decode(ctx *DecodeContext, target reflect.Value) error {
elem := reflect.New(target.Type().Elem()).Elem()
nestedMapper := p.r.ForValue(elem)
if nestedMapper == nil {
return fmt.Errorf("cannot find mapper for %v", target.Type().Elem().String())
}
err := nestedMapper.Decode(ctx, elem)
if err != nil {
return err
}
target.Set(elem.Addr())
return nil
}
func counterMapper() MapperFunc {
@@ -753,7 +781,7 @@ func urlMapper() MapperFunc {
//
// It differs from strings.Split() in that the separator can exist in a field by escaping it with a \. eg.
//
// SplitEscaped(`hello\,there,bob`, ',') == []string{"hello,there", "bob"}
// SplitEscaped(`hello\,there,bob`, ',') == []string{"hello,there", "bob"}
func SplitEscaped(s string, sep rune) (out []string) {
if sep == -1 {
return []string{s}
@@ -786,7 +814,7 @@ func SplitEscaped(s string, sep rune) (out []string) {
// JoinEscaped joins a slice of strings on sep, but also escapes any instances of sep in the fields with \. eg.
//
// JoinEscaped([]string{"hello,there", "bob"}, ',') == `hello\,there,bob`
// JoinEscaped([]string{"hello,there", "bob"}, ',') == `hello\,there,bob`
func JoinEscaped(s []string, sep rune) string {
escaped := []string{}
for _, e := range s {
+3
View File
@@ -318,6 +318,9 @@ func (v *Value) IsMap() bool {
// IsBool returns true if the underlying value is a boolean.
func (v *Value) IsBool() bool {
if m, ok := v.Mapper.(BoolMapperExt); ok && m.IsBoolFromValue(v.Target) {
return true
}
if m, ok := v.Mapper.(BoolMapper); ok && m.IsBool() {
return true
}