feat: allow use of providers that don't return errors

This commit is contained in:
Alec Thomas
2025-01-30 13:39:31 +11:00
parent 9c08a58eb2
commit 4e1757c0e8
3 changed files with 35 additions and 3 deletions
+12 -3
View File
@@ -34,8 +34,17 @@ func (b bindings) addTo(impl, iface any) {
func (b bindings) addProvider(provider any) error {
pv := reflect.ValueOf(provider)
t := pv.Type()
if t.Kind() != reflect.Func || t.NumOut() != 2 || t.Out(1) != reflect.TypeOf((*error)(nil)).Elem() {
return fmt.Errorf("%T must be a function with the signature func(...)(T, error)", provider)
if t.Kind() != reflect.Func {
return fmt.Errorf("%T must be a function", provider)
}
if t.NumOut() == 0 {
return fmt.Errorf("%T must be a function with the signature func(...)(T, error) or func(...) T", provider)
}
if t.NumOut() == 2 {
if t.Out(1) != reflect.TypeOf((*error)(nil)).Elem() {
return fmt.Errorf("missing error; %T must be a function with the signature func(...)(T, error) or func(...) T", provider)
}
}
rt := pv.Type().Out(0)
b[rt] = provider
@@ -143,7 +152,7 @@ func callAnyFunction(f reflect.Value, bindings bindings) (out []any, err error)
if err != nil {
return nil, fmt.Errorf("%s: %w", pt, err)
}
if ferrv := reflect.ValueOf(argv[len(argv)-1]); ferrv.IsValid() && !ferrv.IsNil() {
if ferrv := reflect.ValueOf(argv[len(argv)-1]); ferrv.IsValid() && ferrv.Type().Implements(callbackReturnSignature) && !ferrv.IsNil() {
return nil, ferrv.Interface().(error) //nolint:forcetypeassert
}
in = append(in, reflect.ValueOf(argv[0]))
+3
View File
@@ -119,6 +119,9 @@ func (c *Context) BindTo(impl, iface any) {
//
// This is useful when the Run() function of different commands require different values that may
// not all be initialisable from the main() function.
//
// "provider" must be a function with the signature func(...) (T, error) or func(...) T, where
// ... will be recursively injected with bound values.
func (c *Context) BindToProvider(provider any) error {
return c.bindings.addProvider(provider)
}
+20
View File
@@ -2521,3 +2521,23 @@ func TestIssue483EmptyRootNodeNoRun(t *testing.T) {
assert.Error(t, err)
assert.Contains(t, err.Error(), "no command selected")
}
type providerWithoutErrorCLI struct {
}
func (p *providerWithoutErrorCLI) Run(name string) error {
if name == "Bob" {
return nil
}
return fmt.Errorf("name %s is not Bob", name)
}
func TestProviderWithoutError(t *testing.T) {
k := mustNew(t, &providerWithoutErrorCLI{})
kctx, err := k.Parse(nil)
assert.NoError(t, err)
err = kctx.BindToProvider(func() string { return "Bob" })
assert.NoError(t, err)
err = kctx.Run()
assert.NoError(t, err)
}