feat: allow use of providers that don't return errors

This commit is contained in:
Alec Thomas
2025-01-30 13:39:31 +11:00
parent 9c08a58eb2
commit 4e1757c0e8
3 changed files with 35 additions and 3 deletions
+12 -3
View File
@@ -34,8 +34,17 @@ func (b bindings) addTo(impl, iface any) {
func (b bindings) addProvider(provider any) error {
pv := reflect.ValueOf(provider)
t := pv.Type()
if t.Kind() != reflect.Func || t.NumOut() != 2 || t.Out(1) != reflect.TypeOf((*error)(nil)).Elem() {
return fmt.Errorf("%T must be a function with the signature func(...)(T, error)", provider)
if t.Kind() != reflect.Func {
return fmt.Errorf("%T must be a function", provider)
}
if t.NumOut() == 0 {
return fmt.Errorf("%T must be a function with the signature func(...)(T, error) or func(...) T", provider)
}
if t.NumOut() == 2 {
if t.Out(1) != reflect.TypeOf((*error)(nil)).Elem() {
return fmt.Errorf("missing error; %T must be a function with the signature func(...)(T, error) or func(...) T", provider)
}
}
rt := pv.Type().Out(0)
b[rt] = provider
@@ -143,7 +152,7 @@ func callAnyFunction(f reflect.Value, bindings bindings) (out []any, err error)
if err != nil {
return nil, fmt.Errorf("%s: %w", pt, err)
}
if ferrv := reflect.ValueOf(argv[len(argv)-1]); ferrv.IsValid() && !ferrv.IsNil() {
if ferrv := reflect.ValueOf(argv[len(argv)-1]); ferrv.IsValid() && ferrv.Type().Implements(callbackReturnSignature) && !ferrv.IsNil() {
return nil, ferrv.Interface().(error) //nolint:forcetypeassert
}
in = append(in, reflect.ValueOf(argv[0]))