diff --git a/callbacks.go b/callbacks.go index f488395..8b78741 100644 --- a/callbacks.go +++ b/callbacks.go @@ -79,8 +79,8 @@ func getMethod(value reflect.Value, name string) reflect.Value { func callMethod(name string, v, f reflect.Value, bindings bindings) error { in := []reflect.Value{} t := f.Type() - if t.NumOut() != 1 || t.Out(0) != callbackReturnSignature { - return fmt.Errorf("return value of %T.%s() must be exactly \"error\"", v.Type(), name) + if t.NumOut() != 1 || !t.Out(0).Implements(callbackReturnSignature) { + return fmt.Errorf("return value of %T.%s() must implement \"error\"", v.Type(), name) } for i := 0; i < t.NumIn(); i++ { pt := t.In(i) diff --git a/options_test.go b/options_test.go index c48949e..4d401d9 100644 --- a/options_test.go +++ b/options_test.go @@ -42,6 +42,51 @@ func TestBindTo(t *testing.T) { require.Equal(t, "foo", saw) } +func TestInvalidCallback(t *testing.T) { + type iface interface { + Method() + } + + saw := "" + method := func(i iface) string { + saw = string(i.(impl)) + return saw + } + + var cli struct{} + + p, err := New(&cli, BindTo(impl("foo"), (*iface)(nil))) + require.NoError(t, err) + err = callMethod("method", reflect.ValueOf(impl("??")), reflect.ValueOf(method), p.bindings) + require.EqualError(t, err, `return value of *reflect.rtype.method() must implement "error"`) +} + +type zrror struct{} + +func (*zrror) Error() string { + return "error" +} + +func TestCallbackCustomError(t *testing.T) { + type iface interface { + Method() + } + + saw := "" + method := func(i iface) *zrror { + saw = string(i.(impl)) + return nil + } + + var cli struct{} + + p, err := New(&cli, BindTo(impl("foo"), (*iface)(nil))) + require.NoError(t, err) + err = callMethod("method", reflect.ValueOf(impl("??")), reflect.ValueOf(method), p.bindings) + require.NoError(t, err) + require.Equal(t, "foo", saw) +} + type bindToProviderCLI struct { Called bool Cmd bindToProviderCmd `cmd:""`