feat: allow hooks to be declared on embedded fields

Specifically, on Go embedded fields, not on fields tagged with `embed`.

Fixes #90.
This commit is contained in:
Alec Thomas
2024-12-27 13:41:00 +11:00
parent 565ae9b740
commit 840220c2ed
4 changed files with 75 additions and 53 deletions
+2 -36
View File
@@ -307,8 +307,8 @@ func main() {
## Hooks: BeforeReset(), BeforeResolve(), BeforeApply(), AfterApply() and the Bind() option
If a node in the grammar has a `BeforeReset(...)`, `BeforeResolve
(...)`, `BeforeApply(...) error` and/or `AfterApply(...) error` method, those
If a node in the CLI, or any of its embedded fields, has a `BeforeReset(...) error`, `BeforeResolve
(...) error`, `BeforeApply(...) error` and/or `AfterApply(...) error` method, those
methods will be called before values are reset, before validation/assignment,
and after validation/assignment, respectively.
@@ -341,40 +341,6 @@ func main() {
}
```
Another example of using hooks is load the env-file:
```go
package main
import (
"fmt"
"github.com/alecthomas/kong"
"github.com/joho/godotenv"
)
type EnvFlag string
// BeforeResolve loads env file.
func (c EnvFlag) BeforeReset(ctx *kong.Context, trace *kong.Path) error {
path := string(ctx.FlagValue(trace.Flag).(EnvFlag)) // nolint
path = kong.ExpandPath(path)
if err := godotenv.Load(path); err != nil {
return err
}
return nil
}
var CLI struct {
EnvFile EnvFlag
Flag `env:"FLAG"`
}
func main() {
_ = kong.Parse(&CLI)
fmt.Println(CLI.Flag)
}
```
## Flags
Any [mapped](#mapper---customising-how-the-command-line-is-mapped-to-go-values) field in the command structure _not_ tagged with `cmd` or `arg` will be a flag. Flags are optional by default.
+27
View File
@@ -68,6 +68,33 @@ func getMethod(value reflect.Value, name string) reflect.Value {
return method
}
// Get methods from the given value and any embedded fields.
func getMethods(value reflect.Value, name string) []reflect.Value {
// Collect all possible receivers
receivers := []reflect.Value{value}
if value.Kind() == reflect.Ptr {
value = value.Elem()
}
if value.Kind() == reflect.Struct {
t := value.Type()
for i := 0; i < value.NumField(); i++ {
field := value.Field(i)
fieldType := t.Field(i)
if fieldType.IsExported() && fieldType.Anonymous {
receivers = append(receivers, field)
}
}
}
// Search all receivers for methods
var methods []reflect.Value
for _, receiver := range receivers {
if method := getMethod(receiver, name); method.IsValid() {
methods = append(methods, method)
}
}
return methods
}
func callFunction(f reflect.Value, bindings bindings) error {
if f.Kind() != reflect.Func {
return fmt.Errorf("expected function, got %s", f.Type())
+13 -17
View File
@@ -361,16 +361,14 @@ func (k *Kong) applyHook(ctx *Context, name string) error {
default:
panic("unsupported Path")
}
method := getMethod(value, name)
if !method.IsValid() {
continue
}
binds := k.bindings.clone()
binds.add(ctx, trace)
binds.add(trace.Node().Vars().CloneWith(k.vars))
binds.merge(ctx.bindings)
if err := callFunction(method, binds); err != nil {
return err
for _, method := range getMethods(value, name) {
binds := k.bindings.clone()
binds.add(ctx, trace)
binds.add(trace.Node().Vars().CloneWith(k.vars))
binds.merge(ctx.bindings)
if err := callFunction(method, binds); err != nil {
return err
}
}
}
// Path[0] will always be the app root.
@@ -392,13 +390,11 @@ func (k *Kong) applyHookToDefaultFlags(ctx *Context, node *Node, name string) er
if !flag.HasDefault || ctx.values[flag.Value].IsValid() || !flag.Target.IsValid() {
continue
}
method := getMethod(flag.Target, name)
if !method.IsValid() {
continue
}
path := &Path{Flag: flag}
if err := callFunction(method, binds.clone().add(path)); err != nil {
return next(err)
for _, method := range getMethods(flag.Target, name) {
path := &Path{Flag: flag}
if err := callFunction(method, binds.clone().add(path)); err != nil {
return next(err)
}
}
}
return next(nil)
+33
View File
@@ -2406,3 +2406,36 @@ func TestProviderMethods(t *testing.T) {
err = kctx.Run(t)
assert.NoError(t, err)
}
type EmbeddedCallback struct {
Embedded bool
}
func (e *EmbeddedCallback) AfterApply() error {
e.Embedded = true
return nil
}
type EmbeddedRoot struct {
EmbeddedCallback
Root bool
}
func (e *EmbeddedRoot) AfterApply() error {
e.Root = true
return nil
}
func TestEmbeddedCallbacks(t *testing.T) {
actual := &EmbeddedRoot{}
k := mustNew(t, actual)
_, err := k.Parse(nil)
assert.NoError(t, err)
expected := &EmbeddedRoot{
EmbeddedCallback: EmbeddedCallback{
Embedded: true,
},
Root: true,
}
assert.Equal(t, expected, actual)
}