feat: allow hooks to be declared on embedded fields
Specifically, on Go embedded fields, not on fields tagged with `embed`. Fixes #90.
This commit is contained in:
@@ -307,8 +307,8 @@ func main() {
|
||||
|
||||
## Hooks: BeforeReset(), BeforeResolve(), BeforeApply(), AfterApply() and the Bind() option
|
||||
|
||||
If a node in the grammar has a `BeforeReset(...)`, `BeforeResolve
|
||||
(...)`, `BeforeApply(...) error` and/or `AfterApply(...) error` method, those
|
||||
If a node in the CLI, or any of its embedded fields, has a `BeforeReset(...) error`, `BeforeResolve
|
||||
(...) error`, `BeforeApply(...) error` and/or `AfterApply(...) error` method, those
|
||||
methods will be called before values are reset, before validation/assignment,
|
||||
and after validation/assignment, respectively.
|
||||
|
||||
@@ -341,40 +341,6 @@ func main() {
|
||||
}
|
||||
```
|
||||
|
||||
Another example of using hooks is load the env-file:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/alecthomas/kong"
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
type EnvFlag string
|
||||
|
||||
// BeforeResolve loads env file.
|
||||
func (c EnvFlag) BeforeReset(ctx *kong.Context, trace *kong.Path) error {
|
||||
path := string(ctx.FlagValue(trace.Flag).(EnvFlag)) // nolint
|
||||
path = kong.ExpandPath(path)
|
||||
if err := godotenv.Load(path); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var CLI struct {
|
||||
EnvFile EnvFlag
|
||||
Flag `env:"FLAG"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
_ = kong.Parse(&CLI)
|
||||
fmt.Println(CLI.Flag)
|
||||
}
|
||||
```
|
||||
|
||||
## Flags
|
||||
|
||||
Any [mapped](#mapper---customising-how-the-command-line-is-mapped-to-go-values) field in the command structure _not_ tagged with `cmd` or `arg` will be a flag. Flags are optional by default.
|
||||
|
||||
@@ -68,6 +68,33 @@ func getMethod(value reflect.Value, name string) reflect.Value {
|
||||
return method
|
||||
}
|
||||
|
||||
// Get methods from the given value and any embedded fields.
|
||||
func getMethods(value reflect.Value, name string) []reflect.Value {
|
||||
// Collect all possible receivers
|
||||
receivers := []reflect.Value{value}
|
||||
if value.Kind() == reflect.Ptr {
|
||||
value = value.Elem()
|
||||
}
|
||||
if value.Kind() == reflect.Struct {
|
||||
t := value.Type()
|
||||
for i := 0; i < value.NumField(); i++ {
|
||||
field := value.Field(i)
|
||||
fieldType := t.Field(i)
|
||||
if fieldType.IsExported() && fieldType.Anonymous {
|
||||
receivers = append(receivers, field)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Search all receivers for methods
|
||||
var methods []reflect.Value
|
||||
for _, receiver := range receivers {
|
||||
if method := getMethod(receiver, name); method.IsValid() {
|
||||
methods = append(methods, method)
|
||||
}
|
||||
}
|
||||
return methods
|
||||
}
|
||||
|
||||
func callFunction(f reflect.Value, bindings bindings) error {
|
||||
if f.Kind() != reflect.Func {
|
||||
return fmt.Errorf("expected function, got %s", f.Type())
|
||||
|
||||
@@ -361,16 +361,14 @@ func (k *Kong) applyHook(ctx *Context, name string) error {
|
||||
default:
|
||||
panic("unsupported Path")
|
||||
}
|
||||
method := getMethod(value, name)
|
||||
if !method.IsValid() {
|
||||
continue
|
||||
}
|
||||
binds := k.bindings.clone()
|
||||
binds.add(ctx, trace)
|
||||
binds.add(trace.Node().Vars().CloneWith(k.vars))
|
||||
binds.merge(ctx.bindings)
|
||||
if err := callFunction(method, binds); err != nil {
|
||||
return err
|
||||
for _, method := range getMethods(value, name) {
|
||||
binds := k.bindings.clone()
|
||||
binds.add(ctx, trace)
|
||||
binds.add(trace.Node().Vars().CloneWith(k.vars))
|
||||
binds.merge(ctx.bindings)
|
||||
if err := callFunction(method, binds); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
// Path[0] will always be the app root.
|
||||
@@ -392,13 +390,11 @@ func (k *Kong) applyHookToDefaultFlags(ctx *Context, node *Node, name string) er
|
||||
if !flag.HasDefault || ctx.values[flag.Value].IsValid() || !flag.Target.IsValid() {
|
||||
continue
|
||||
}
|
||||
method := getMethod(flag.Target, name)
|
||||
if !method.IsValid() {
|
||||
continue
|
||||
}
|
||||
path := &Path{Flag: flag}
|
||||
if err := callFunction(method, binds.clone().add(path)); err != nil {
|
||||
return next(err)
|
||||
for _, method := range getMethods(flag.Target, name) {
|
||||
path := &Path{Flag: flag}
|
||||
if err := callFunction(method, binds.clone().add(path)); err != nil {
|
||||
return next(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return next(nil)
|
||||
|
||||
@@ -2406,3 +2406,36 @@ func TestProviderMethods(t *testing.T) {
|
||||
err = kctx.Run(t)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
type EmbeddedCallback struct {
|
||||
Embedded bool
|
||||
}
|
||||
|
||||
func (e *EmbeddedCallback) AfterApply() error {
|
||||
e.Embedded = true
|
||||
return nil
|
||||
}
|
||||
|
||||
type EmbeddedRoot struct {
|
||||
EmbeddedCallback
|
||||
Root bool
|
||||
}
|
||||
|
||||
func (e *EmbeddedRoot) AfterApply() error {
|
||||
e.Root = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestEmbeddedCallbacks(t *testing.T) {
|
||||
actual := &EmbeddedRoot{}
|
||||
k := mustNew(t, actual)
|
||||
_, err := k.Parse(nil)
|
||||
assert.NoError(t, err)
|
||||
expected := &EmbeddedRoot{
|
||||
EmbeddedCallback: EmbeddedCallback{
|
||||
Embedded: true,
|
||||
},
|
||||
Root: true,
|
||||
}
|
||||
assert.Equal(t, expected, actual)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user