diff --git a/callbacks.go b/callbacks.go index e541f5a..b4fe3ca 100644 --- a/callbacks.go +++ b/callbacks.go @@ -34,8 +34,17 @@ func (b bindings) addTo(impl, iface any) { func (b bindings) addProvider(provider any) error { pv := reflect.ValueOf(provider) t := pv.Type() - if t.Kind() != reflect.Func || t.NumOut() != 2 || t.Out(1) != reflect.TypeOf((*error)(nil)).Elem() { - return fmt.Errorf("%T must be a function with the signature func(...)(T, error)", provider) + if t.Kind() != reflect.Func { + return fmt.Errorf("%T must be a function", provider) + } + + if t.NumOut() == 0 { + return fmt.Errorf("%T must be a function with the signature func(...)(T, error) or func(...) T", provider) + } + if t.NumOut() == 2 { + if t.Out(1) != reflect.TypeOf((*error)(nil)).Elem() { + return fmt.Errorf("missing error; %T must be a function with the signature func(...)(T, error) or func(...) T", provider) + } } rt := pv.Type().Out(0) b[rt] = provider @@ -143,7 +152,7 @@ func callAnyFunction(f reflect.Value, bindings bindings) (out []any, err error) if err != nil { return nil, fmt.Errorf("%s: %w", pt, err) } - if ferrv := reflect.ValueOf(argv[len(argv)-1]); ferrv.IsValid() && !ferrv.IsNil() { + if ferrv := reflect.ValueOf(argv[len(argv)-1]); ferrv.IsValid() && ferrv.Type().Implements(callbackReturnSignature) && !ferrv.IsNil() { return nil, ferrv.Interface().(error) //nolint:forcetypeassert } in = append(in, reflect.ValueOf(argv[0])) diff --git a/context.go b/context.go index b6a56e3..ebf4c31 100644 --- a/context.go +++ b/context.go @@ -119,6 +119,9 @@ func (c *Context) BindTo(impl, iface any) { // // This is useful when the Run() function of different commands require different values that may // not all be initialisable from the main() function. +// +// "provider" must be a function with the signature func(...) (T, error) or func(...) T, where +// ... will be recursively injected with bound values. func (c *Context) BindToProvider(provider any) error { return c.bindings.addProvider(provider) } diff --git a/kong_test.go b/kong_test.go index 874b90b..2ceb1b1 100644 --- a/kong_test.go +++ b/kong_test.go @@ -2521,3 +2521,23 @@ func TestIssue483EmptyRootNodeNoRun(t *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "no command selected") } + +type providerWithoutErrorCLI struct { +} + +func (p *providerWithoutErrorCLI) Run(name string) error { + if name == "Bob" { + return nil + } + return fmt.Errorf("name %s is not Bob", name) +} + +func TestProviderWithoutError(t *testing.T) { + k := mustNew(t, &providerWithoutErrorCLI{}) + kctx, err := k.Parse(nil) + assert.NoError(t, err) + err = kctx.BindToProvider(func() string { return "Bob" }) + assert.NoError(t, err) + err = kctx.Run() + assert.NoError(t, err) +}