Files
kong/callbacks.go
T
Abhinav Gupta 9c08a58eb2 Support hooks on embed:"" fields (#493)
Relates to 840220c (#90)

This change adds support for hooks to be called on fields
that are tagged with `embed:""`.

### Use case

If a command has several subcommands,
many (but not all) of which need the same external resource,
this allows defining the flag-level inputs for that resource centrally,
and then using `embed:""` in any command that needs that resource.

For example, imagine:

```go
type githubClientProvider struct {
    Token string `name:"github-token" env:"GITHUB_TOKEN"`
    URL   string `name:"github-url" env:"GITHUB_URL"`
}

func (g *githubClientProvider) BeforeApply(kctx *kong.Context) error {
  return kctx.BindToProvider(func() (*github.Client, error) {
    return github.NewClient(...), nil
  })
}
```

Then, any command that needs a GitHub client will add this field,
along with any other resource providers it needs,
and add parameters to its `Run` method to accept those resources:

```go
type listUsersCmd struct {
    GitHub githubClientProvider `embed:""`
    S3     s3ClientProvider     `embed:""`
}

func (l *listUsersCmd) Run(gh *github.Client, s3 *s3.Client) error {
    ...
}
```

### Alternatives

It is possible to do the same today if the `*Provider` struct above
is actually a Go embed instead of a Kong embed, *and* it is exported.

```go
type GitHubClientProvider struct{ ... }

type listUsersCmd struct {
    GitHubClientProvider
    S3ClientProvider
}
```

The difference is whether the struct defining the flags
is required to be exported or not.
2025-01-29 16:04:52 +11:00

158 lines
4.0 KiB
Go

package kong
import (
"fmt"
"reflect"
"strings"
)
// bindings maps a type to a function that returns a value of that type.
//
// Each function should have the signature func(...) (T, error). When a
// binding is resolved, the function's own arguments are looked up in the same
// set of bindings, so providers are resolved recursively.
type bindings map[reflect.Type]any
// String renders the set of bound types in the form "bindings{T1, T2}".
//
// Types are listed in map iteration order, so the order of the names is not
// stable from one call to the next.
func (b bindings) String() string {
	names := make([]string, 0, len(b))
	for typ := range b {
		names = append(names, typ.String())
	}
	joined := strings.Join(names, ", ")
	return "bindings{" + joined + "}"
}
// add binds each value to its own dynamic type and returns b so that calls
// can be chained.
//
// Every value is wrapped in a provider function that returns it unchanged
// with a nil error. A later value of the same type replaces an earlier one.
func (b bindings) add(values ...any) bindings {
	for i := range values {
		// Declared inside the loop body so each closure captures its own value.
		val := values[i]
		b[reflect.TypeOf(val)] = func() (any, error) { return val, nil }
	}
	return b
}
// addTo binds impl under the element type of iface.
//
// The key is reflect.TypeOf(iface).Elem(), so iface must be a value whose
// type has an element type; presumably callers pass a pointer to the target
// interface, e.g. (*I)(nil) - confirm against callers. The stored provider
// returns impl unchanged with a nil error.
func (b bindings) addTo(impl, iface any) {
	target := reflect.TypeOf(iface).Elem()
	provider := func() (any, error) { return impl, nil }
	b[target] = provider
}
// addProvider registers provider as the source of values for its first
// return type.
//
// provider must be a function with exactly two results, the second of which
// is the error interface: func(...) (T, error). The binding is keyed by T.
// Anything else, including a nil provider, is rejected with an error and
// leaves b unchanged.
func (b bindings) addProvider(provider any) error {
	// reflect.TypeOf returns nil for a nil interface value, which lets a nil
	// provider be reported as an error rather than panicking.
	t := reflect.TypeOf(provider)
	errType := reflect.TypeOf((*error)(nil)).Elem()
	if t == nil || t.Kind() != reflect.Func || t.NumOut() != 2 || t.Out(1) != errType {
		return fmt.Errorf("%T must be a function with the signature func(...)(T, error)", provider)
	}
	b[t.Out(0)] = provider
	return nil
}
// clone returns a shallow copy of b.
//
// The copy holds the same provider values as b, but entries can be added to
// or replaced in either map without affecting the other. The result is never
// nil, even when b is.
func (b bindings) clone() bindings {
	dup := make(bindings, len(b))
	for typ, provider := range b {
		dup[typ] = provider
	}
	return dup
}
// merge copies every entry of other into b and returns b.
//
// When both maps bind the same type, the entry from other wins.
func (b bindings) merge(other bindings) bindings {
	for typ, provider := range other {
		b[typ] = provider
	}
	return b
}
// getMethod looks up the method called name on value.
//
// If the method is not in value's own method set and value is addressable,
// the lookup is retried on value's address, which finds methods declared on
// the pointer receiver. The result is the zero reflect.Value when no method
// is found.
func getMethod(value reflect.Value, name string) reflect.Value {
	if direct := value.MethodByName(name); direct.IsValid() {
		return direct
	}
	if !value.CanAddr() {
		return reflect.Value{}
	}
	return value.Addr().MethodByName(name)
}
// getMethods collects the methods called name from value and from selected
// fields of the struct it is, or points to.
//
// value itself is always considered first. If value is a struct, or a pointer
// to one, each exported field is also considered when it is either an
// anonymous (Go-embedded) field or carries an `embed` struct tag. Unexported
// fields are never considered. The returned slice holds one entry for every
// candidate on which getMethod finds the method, in candidate order, and is
// nil when none is found.
func getMethods(value reflect.Value, name string) []reflect.Value {
	// The value as passed in is always the first candidate receiver.
	candidates := []reflect.Value{value}

	target := value
	if target.Kind() == reflect.Ptr {
		target = target.Elem()
	}
	if target.Kind() == reflect.Struct {
		structType := target.Type()
		for idx := 0; idx < structType.NumField(); idx++ {
			sf := structType.Field(idx)
			if !sf.IsExported() {
				continue
			}
			// Hooks are called on exported fields that are either Go-embedded
			// or named fields tagged with `embed:""`.
			_, tagged := sf.Tag.Lookup("embed")
			if sf.Anonymous || tagged {
				candidates = append(candidates, target.Field(idx))
			}
		}
	}

	var found []reflect.Value
	for _, candidate := range candidates {
		method := getMethod(candidate, name)
		if method.IsValid() {
			found = append(found, method)
		}
	}
	return found
}
// callFunction invokes the callback f, resolving its parameters from
// bindings, and returns the error the callback produced.
//
// f must be a function with exactly one result whose type implements
// callbackReturnSignature; otherwise an error describing the mismatch is
// returned and f is not called. A failure to resolve f's parameters is
// returned as is. A result that is nil, or a nil pointer, is reported as a
// nil error.
func callFunction(f reflect.Value, bindings bindings) error {
	if f.Kind() != reflect.Func {
		return fmt.Errorf("expected function, got %s", f.Type())
	}
	sig := f.Type()
	if sig.NumOut() != 1 || !sig.Out(0).Implements(callbackReturnSignature) {
		return fmt.Errorf("return value of %s must implement \"error\"", sig)
	}

	results, err := callAnyFunction(f, bindings)
	if err != nil {
		return err
	}

	result := results[0]
	rv := reflect.ValueOf(result)
	if !rv.IsValid() {
		return nil
	}
	// IsNil is only consulted for these kinds; a typed nil pointer counts as
	// "no error".
	switch rv.Kind() {
	case reflect.Interface, reflect.Pointer:
		if rv.IsNil() {
			return nil
		}
	}
	return result.(error) //nolint:forcetypeassert
}
// callAnyFunction calls f with arguments resolved from bindings and returns
// f's results converted to interface values.
//
// For every parameter of f, the binding registered for that parameter's type
// is looked up and invoked through callAnyFunction itself, so a provider's
// own parameters are resolved recursively. The provider's first result
// becomes the argument; its last result is treated as the provider's error.
//
// An error is returned, and f is not called, when f is not a function, when
// a parameter type has no binding, when a provider cannot be resolved, or
// when a provider reports a non-nil error. A provider error whose dynamic
// value is a nil pointer (or other nil reference) counts as no error.
func callAnyFunction(f reflect.Value, bindings bindings) (out []any, err error) {
	if f.Kind() != reflect.Func {
		return nil, fmt.Errorf("expected function, got %s", f.Type())
	}
	t := f.Type()
	in := make([]reflect.Value, 0, t.NumIn())
	for i := 0; i < t.NumIn(); i++ {
		pt := t.In(i)
		argf, ok := bindings[pt]
		if !ok {
			return nil, fmt.Errorf("couldn't find binding of type %s for parameter %d of %s(), use kong.Bind(%s)", pt, i, t, pt)
		}
		// Recursively resolve binding functions.
		argv, err := callAnyFunction(reflect.ValueOf(argf), bindings)
		if err != nil {
			return nil, fmt.Errorf("%s: %w", pt, err)
		}
		if ferr, ok := argv[len(argv)-1].(error); ok {
			// reflect.Value.IsNil panics for kinds that cannot be nil, so it
			// is only consulted for error types that can hold a nil; an
			// error of any other kind (e.g. a string-based type) is always
			// a real error.
			ferrv := reflect.ValueOf(ferr)
			switch ferrv.Kind() {
			case reflect.Chan, reflect.Func, reflect.Map, reflect.Pointer, reflect.Slice, reflect.UnsafePointer:
				if !ferrv.IsNil() {
					return nil, ferr
				}
			default:
				return nil, ferr
			}
		}
		arg := reflect.ValueOf(argv[0])
		if !arg.IsValid() {
			// The provider returned an untyped nil (e.g. a nil interface).
			// Call panics on a zero Value, so pass the parameter type's
			// zero value instead.
			arg = reflect.Zero(pt)
		}
		in = append(in, arg)
	}
	outv := f.Call(in)
	out = make([]any, len(outv))
	for i, v := range outv {
		out[i] = v.Interface()
	}
	return out, nil
}