feat: allow hooks to be declared on embedded fields
Specifically, on Go embedded fields, not on fields tagged with `embed`. Fixes #90.
This commit is contained in:
@@ -307,8 +307,8 @@ func main() {
|
|||||||
|
|
||||||
## Hooks: BeforeReset(), BeforeResolve(), BeforeApply(), AfterApply() and the Bind() option
|
## Hooks: BeforeReset(), BeforeResolve(), BeforeApply(), AfterApply() and the Bind() option
|
||||||
|
|
||||||
If a node in the grammar has a `BeforeReset(...)`, `BeforeResolve
|
If a node in the CLI, or any of its embedded fields, has a `BeforeReset(...) error`, `BeforeResolve
|
||||||
(...)`, `BeforeApply(...) error` and/or `AfterApply(...) error` method, those
|
(...) error`, `BeforeApply(...) error` and/or `AfterApply(...) error` method, those
|
||||||
methods will be called before values are reset, before validation/assignment,
|
methods will be called before values are reset, before validation/assignment,
|
||||||
and after validation/assignment, respectively.
|
and after validation/assignment, respectively.
|
||||||
|
|
||||||
@@ -341,40 +341,6 @@ func main() {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Another example of using hooks is load the env-file:
|
|
||||||
|
|
||||||
```go
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"github.com/alecthomas/kong"
|
|
||||||
"github.com/joho/godotenv"
|
|
||||||
)
|
|
||||||
|
|
||||||
type EnvFlag string
|
|
||||||
|
|
||||||
// BeforeResolve loads env file.
|
|
||||||
func (c EnvFlag) BeforeReset(ctx *kong.Context, trace *kong.Path) error {
|
|
||||||
path := string(ctx.FlagValue(trace.Flag).(EnvFlag)) // nolint
|
|
||||||
path = kong.ExpandPath(path)
|
|
||||||
if err := godotenv.Load(path); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var CLI struct {
|
|
||||||
EnvFile EnvFlag
|
|
||||||
Flag `env:"FLAG"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
_ = kong.Parse(&CLI)
|
|
||||||
fmt.Println(CLI.Flag)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Flags
|
## Flags
|
||||||
|
|
||||||
Any [mapped](#mapper---customising-how-the-command-line-is-mapped-to-go-values) field in the command structure _not_ tagged with `cmd` or `arg` will be a flag. Flags are optional by default.
|
Any [mapped](#mapper---customising-how-the-command-line-is-mapped-to-go-values) field in the command structure _not_ tagged with `cmd` or `arg` will be a flag. Flags are optional by default.
|
||||||
|
|||||||
@@ -68,6 +68,33 @@ func getMethod(value reflect.Value, name string) reflect.Value {
|
|||||||
return method
|
return method
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get methods from the given value and any embedded fields.
|
||||||
|
func getMethods(value reflect.Value, name string) []reflect.Value {
|
||||||
|
// Collect all possible receivers
|
||||||
|
receivers := []reflect.Value{value}
|
||||||
|
if value.Kind() == reflect.Ptr {
|
||||||
|
value = value.Elem()
|
||||||
|
}
|
||||||
|
if value.Kind() == reflect.Struct {
|
||||||
|
t := value.Type()
|
||||||
|
for i := 0; i < value.NumField(); i++ {
|
||||||
|
field := value.Field(i)
|
||||||
|
fieldType := t.Field(i)
|
||||||
|
if fieldType.IsExported() && fieldType.Anonymous {
|
||||||
|
receivers = append(receivers, field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Search all receivers for methods
|
||||||
|
var methods []reflect.Value
|
||||||
|
for _, receiver := range receivers {
|
||||||
|
if method := getMethod(receiver, name); method.IsValid() {
|
||||||
|
methods = append(methods, method)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return methods
|
||||||
|
}
|
||||||
|
|
||||||
func callFunction(f reflect.Value, bindings bindings) error {
|
func callFunction(f reflect.Value, bindings bindings) error {
|
||||||
if f.Kind() != reflect.Func {
|
if f.Kind() != reflect.Func {
|
||||||
return fmt.Errorf("expected function, got %s", f.Type())
|
return fmt.Errorf("expected function, got %s", f.Type())
|
||||||
|
|||||||
@@ -361,16 +361,14 @@ func (k *Kong) applyHook(ctx *Context, name string) error {
|
|||||||
default:
|
default:
|
||||||
panic("unsupported Path")
|
panic("unsupported Path")
|
||||||
}
|
}
|
||||||
method := getMethod(value, name)
|
for _, method := range getMethods(value, name) {
|
||||||
if !method.IsValid() {
|
binds := k.bindings.clone()
|
||||||
continue
|
binds.add(ctx, trace)
|
||||||
}
|
binds.add(trace.Node().Vars().CloneWith(k.vars))
|
||||||
binds := k.bindings.clone()
|
binds.merge(ctx.bindings)
|
||||||
binds.add(ctx, trace)
|
if err := callFunction(method, binds); err != nil {
|
||||||
binds.add(trace.Node().Vars().CloneWith(k.vars))
|
return err
|
||||||
binds.merge(ctx.bindings)
|
}
|
||||||
if err := callFunction(method, binds); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Path[0] will always be the app root.
|
// Path[0] will always be the app root.
|
||||||
@@ -392,13 +390,11 @@ func (k *Kong) applyHookToDefaultFlags(ctx *Context, node *Node, name string) er
|
|||||||
if !flag.HasDefault || ctx.values[flag.Value].IsValid() || !flag.Target.IsValid() {
|
if !flag.HasDefault || ctx.values[flag.Value].IsValid() || !flag.Target.IsValid() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
method := getMethod(flag.Target, name)
|
for _, method := range getMethods(flag.Target, name) {
|
||||||
if !method.IsValid() {
|
path := &Path{Flag: flag}
|
||||||
continue
|
if err := callFunction(method, binds.clone().add(path)); err != nil {
|
||||||
}
|
return next(err)
|
||||||
path := &Path{Flag: flag}
|
}
|
||||||
if err := callFunction(method, binds.clone().add(path)); err != nil {
|
|
||||||
return next(err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return next(nil)
|
return next(nil)
|
||||||
|
|||||||
@@ -2406,3 +2406,36 @@ func TestProviderMethods(t *testing.T) {
|
|||||||
err = kctx.Run(t)
|
err = kctx.Run(t)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EmbeddedCallback struct {
|
||||||
|
Embedded bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *EmbeddedCallback) AfterApply() error {
|
||||||
|
e.Embedded = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddedRoot struct {
|
||||||
|
EmbeddedCallback
|
||||||
|
Root bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *EmbeddedRoot) AfterApply() error {
|
||||||
|
e.Root = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbeddedCallbacks(t *testing.T) {
|
||||||
|
actual := &EmbeddedRoot{}
|
||||||
|
k := mustNew(t, actual)
|
||||||
|
_, err := k.Parse(nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
expected := &EmbeddedRoot{
|
||||||
|
EmbeddedCallback: EmbeddedCallback{
|
||||||
|
Embedded: true,
|
||||||
|
},
|
||||||
|
Root: true,
|
||||||
|
}
|
||||||
|
assert.Equal(t, expected, actual)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user