diff --git a/build.go b/build.go index 2ab9f21..be7679e 100644 --- a/build.go +++ b/build.go @@ -25,7 +25,7 @@ func build(k *Kong, ast interface{}) (app *Application, err error) { seenFlags[flag.Name] = true } - node, err := buildNode(k, iv, ApplicationNode, seenFlags) + node, err := buildNode(k, iv, ApplicationNode, newEmptyTag(), seenFlags) if err != nil { return nil, err } @@ -49,7 +49,7 @@ type flattenedField struct { tag *Tag } -func flattenedFields(v reflect.Value) (out []flattenedField, err error) { +func flattenedFields(v reflect.Value, ptag *Tag) (out []flattenedField, err error) { v = reflect.Indirect(v) for i := 0; i < v.NumField(); i++ { ft := v.Type().Field(i) @@ -61,6 +61,15 @@ func flattenedFields(v reflect.Value) (out []flattenedField, err error) { if tag.Ignored { continue } + // Assign group if it's not already set. + if tag.Group == "" { + tag.Group = ptag.Group + } + // Accumulate prefixes. + tag.Prefix = ptag.Prefix + tag.Prefix + tag.EnvPrefix = ptag.EnvPrefix + tag.EnvPrefix + // Combine parent vars. + tag.Vars = ptag.Vars.CloneWith(tag.Vars) // Command and embedded structs can be pointers, so we hydrate them now. if (tag.Cmd || tag.Embed) && ft.Type.Kind() == reflect.Ptr { fv = reflect.New(ft.Type.Elem()).Elem() @@ -68,7 +77,8 @@ func flattenedFields(v reflect.Value) (out []flattenedField, err error) { } if !ft.Anonymous && !tag.Embed { if fv.CanSet() { - out = append(out, flattenedField{field: ft, value: fv, tag: tag}) + field := flattenedField{field: ft, value: fv, tag: tag} + out = append(out, field) } continue } @@ -78,7 +88,7 @@ func flattenedFields(v reflect.Value) (out []flattenedField, err error) { fv = fv.Elem() } else if fv.Type() == reflect.TypeOf(Plugins{}) { for i := 0; i < fv.Len(); i++ { - fields, ferr := flattenedFields(fv.Index(i).Elem()) + fields, ferr := flattenedFields(fv.Index(i).Elem(), tag) if ferr != nil { return nil, ferr } @@ -86,21 +96,10 @@ func flattenedFields(v reflect.Value) (out []flattenedField, err error) { } continue } - sub, err := flattenedFields(fv) + sub, err := flattenedFields(fv, tag) if err != nil { return nil, err } - for _, subf := range sub { - // Assign parent if it's not already set. - if subf.tag.Group == "" { - subf.tag.Group = tag.Group - } - // Accumulate prefixes. - subf.tag.Prefix = tag.Prefix + subf.tag.Prefix - subf.tag.EnvPrefix = tag.EnvPrefix + subf.tag.EnvPrefix - // Combine parent vars. - subf.tag.Vars = tag.Vars.CloneWith(subf.tag.Vars) - } out = append(out, sub...) } return out, nil @@ -109,13 +108,13 @@ func flattenedFields(v reflect.Value) (out []flattenedField, err error) { // Build a Node in the Kong data model. // // "v" is the value to create the node from, "typ" is the output Node type. -func buildNode(k *Kong, v reflect.Value, typ NodeType, seenFlags map[string]bool) (*Node, error) { +func buildNode(k *Kong, v reflect.Value, typ NodeType, tag *Tag, seenFlags map[string]bool) (*Node, error) { node := &Node{ Type: typ, Target: v, - Tag: newEmptyTag(), + Tag: tag, } - fields, err := flattenedFields(v) + fields, err := flattenedFields(v, tag) if err != nil { return nil, err } @@ -201,7 +200,7 @@ func validatePositionalArguments(node *Node) error { } func buildChild(k *Kong, node *Node, typ NodeType, v reflect.Value, ft reflect.StructField, fv reflect.Value, tag *Tag, name string, seenFlags map[string]bool) error { - child, err := buildNode(k, fv, typ, seenFlags) + child, err := buildNode(k, fv, typ, newEmptyTag(), seenFlags) if err != nil { return err } diff --git a/callbacks.go b/callbacks.go index 4e38e17..3a8a45f 100644 --- a/callbacks.go +++ b/callbacks.go @@ -74,11 +74,14 @@ func getMethod(value reflect.Value, name string) reflect.Value { return method } -func callMethod(name string, v, f reflect.Value, bindings bindings) error { +func callFunction(f reflect.Value, bindings bindings) error { + if f.Kind() != reflect.Func { + return fmt.Errorf("expected function, got %s", f.Type()) + } in := []reflect.Value{} t := f.Type() if t.NumOut() != 1 || !t.Out(0).Implements(callbackReturnSignature) { - return fmt.Errorf("return value of %T.%s() must implement \"error\"", v.Type(), name) + return fmt.Errorf("return value of %s must implement \"error\"", t) } for i := 0; i < t.NumIn(); i++ { pt := t.In(i) @@ -89,7 +92,7 @@ func callMethod(name string, v, f reflect.Value, bindings bindings) error { } in = append(in, argv) } else { - return fmt.Errorf("couldn't find binding of type %s for parameter %d of %s.%s(), use kong.Bind(%s)", pt, i, v.Type(), name, pt) + return fmt.Errorf("couldn't find binding of type %s for parameter %d of %s(), use kong.Bind(%s)", pt, i, t, pt) } } out := f.Call(in) @@ -98,3 +101,37 @@ func callMethod(name string, v, f reflect.Value, bindings bindings) error { } return out[0].Interface().(error) // nolint } + +func callAnyFunction(f reflect.Value, bindings bindings) (out []any, err error) { + if f.Kind() != reflect.Func { + return nil, fmt.Errorf("expected function, got %s", f.Type()) + } + in := []reflect.Value{} + t := f.Type() + for i := 0; i < t.NumIn(); i++ { + pt := t.In(i) + if argf, ok := bindings[pt]; ok { + argv, err := argf() + if err != nil { + return nil, err + } + in = append(in, argv) + } else { + return nil, fmt.Errorf("couldn't find binding of type %s for parameter %d of %s(), use kong.Bind(%s)", pt, i, t, pt) + } + } + outv := f.Call(in) + out = make([]any, len(outv)) + for i, v := range outv { + out[i] = v.Interface() + } + return out, nil +} + +func callMethod(name string, v, f reflect.Value, bindings bindings) error { + err := callFunction(f, bindings) + if err != nil { + return fmt.Errorf("%s.%s(): %w", v.Type(), name, err) + } + return nil +} diff --git a/context.go b/context.go index de43925..de9e408 100644 --- a/context.go +++ b/context.go @@ -110,7 +110,7 @@ func (c *Context) Bind(args ...interface{}) { // // This will typically have to be called like so: // -// BindTo(impl, (*MyInterface)(nil)) +// BindTo(impl, (*MyInterface)(nil)) func (c *Context) BindTo(impl, iface interface{}) { c.bindings.addTo(impl, iface) } @@ -719,6 +719,13 @@ func (c *Context) parseFlag(flags []*Flag, match string) (err error) { return findPotentialCandidates(match, candidates, "unknown flag %s", match) } +// Call an arbitrary function filling arguments with bound values. +func (c *Context) Call(fn any, binds ...interface{}) (out []interface{}, err error) { + fv := reflect.ValueOf(fn) + bindings := c.Kong.bindings.clone().add(binds...).add(c).merge(c.bindings) //nolint:govet + return callAnyFunction(fv, bindings) +} + // RunNode calls the Run() method on an arbitrary node. // // This is useful in conjunction with Visit(), for dynamically running commands. diff --git a/kong.go b/kong.go index 19e2b85..1ae9188 100644 --- a/kong.go +++ b/kong.go @@ -68,6 +68,7 @@ type Kong struct { // Set temporarily by Options. These are applied after build(). postBuildOptions []Option + embedded []embedded dynamicCommands []*dynamicCommand } @@ -110,6 +111,25 @@ func New(grammar interface{}, options ...Option) (*Kong, error) { k.Model = model k.Model.HelpFlag = k.helpFlag + // Embed any embedded structs. + for _, embed := range k.embedded { + tag, err := parseTagString(strings.Join(embed.tags, " ")) //nolint:govet + if err != nil { + return nil, err + } + tag.Embed = true + v := reflect.Indirect(reflect.ValueOf(embed.strct)) + node, err := buildNode(k, v, CommandNode, tag, map[string]bool{}) + if err != nil { + return nil, err + } + for _, child := range node.Children { + child.Parent = k.Model.Node + k.Model.Children = append(k.Model.Children, child) + } + k.Model.Flags = append(k.Model.Flags, node.Flags...) + } + // Synthesise command nodes. for _, dcmd := range k.dynamicCommands { tag, terr := parseTagString(strings.Join(dcmd.tags, " ")) @@ -188,6 +208,10 @@ func (k *Kong) interpolateValue(value *Value, vars Vars) (err error) { vars = vars.CloneWith(varsContributor.Vars(value)) } + if value.Enum, err = interpolate(value.Enum, vars, nil); err != nil { + return fmt.Errorf("enum for %s: %s", value.Summary(), err) + } + updatedVars := map[string]string{ "default": value.Default, "enum": value.Enum, diff --git a/options.go b/options.go index 02d33fd..ec724b0 100644 --- a/options.go +++ b/options.go @@ -55,6 +55,25 @@ func Exit(exit func(int)) Option { }) } +type embedded struct { + strct any + tags []string +} + +// Embed a struct into the root of the CLI. +// +// "strct" must be a pointer to a structure. +func Embed(strct any, tags ...string) Option { + t := reflect.TypeOf(strct) + if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct { + panic("kong: Embed() must be called with a pointer to a struct") + } + return OptionFunc(func(k *Kong) error { + k.embedded = append(k.embedded, embedded{strct, tags}) + return nil + }) +} + type dynamicCommand struct { name string help string @@ -164,8 +183,8 @@ func Writers(stdout, stderr io.Writer) Option { // // There are two hook points: // -// BeforeApply(...) error -// AfterApply(...) error +// BeforeApply(...) error +// AfterApply(...) error // // Called before validation/assignment, and immediately after validation/assignment, respectively. func Bind(args ...interface{}) Option { @@ -177,7 +196,7 @@ func Bind(args ...interface{}) Option { // BindTo allows binding of implementations to interfaces. // -// BindTo(impl, (*iface)(nil)) +// BindTo(impl, (*iface)(nil)) func BindTo(impl, iface interface{}) Option { return OptionFunc(func(k *Kong) error { k.bindings.addTo(impl, iface) @@ -428,7 +447,8 @@ func siftStrings(ss []string, filter func(s string) bool) []string { // Predefined environment variables are skipped. // // For example: -// --some.value -> PREFIX_SOME_VALUE +// +// --some.value -> PREFIX_SOME_VALUE func DefaultEnvars(prefix string) Option { processFlag := func(flag *Flag) { switch env := flag.Env; { diff --git a/options_test.go b/options_test.go index e79168b..a77853a 100644 --- a/options_test.go +++ b/options_test.go @@ -58,7 +58,7 @@ func TestInvalidCallback(t *testing.T) { p, err := New(&cli, BindTo(impl("foo"), (*iface)(nil))) assert.NoError(t, err) err = callMethod("method", reflect.ValueOf(impl("??")), reflect.ValueOf(method), p.bindings) - assert.EqualError(t, err, `return value of *reflect.rtype.method() must implement "error"`) + assert.EqualError(t, err, `kong.impl.method(): return value of func(kong.iface) string must implement "error"`) } type zrror struct{} diff --git a/tag.go b/tag.go index b471613..6a94b2d 100644 --- a/tag.go +++ b/tag.go @@ -44,6 +44,16 @@ type Tag struct { items map[string][]string } +func (t *Tag) String() string { + out := []string{} + for key, list := range t.items { + for _, value := range list { + out = append(out, fmt.Sprintf("%s:%q", key, value)) + } + } + return strings.Join(out, " ") +} + type tagChars struct { sep, quote, assign rune }