diff --git a/callbacks.go b/callbacks.go index 2a296d0..6096a26 100644 --- a/callbacks.go +++ b/callbacks.go @@ -6,10 +6,59 @@ import ( "strings" ) +// binding is a single binding registered with Kong. +type binding struct { + // fn is a function that returns a value of the target type. + fn reflect.Value + + // val is a value of the target type. + // Must be set if done and singleton are true. + val reflect.Value + + // singleton indicates whether the binding is a singleton. + // If true, the binding will be resolved once and cached. + singleton bool + + // done indicates whether a singleton binding has been resolved. + // If singleton is false, this field is ignored. + done bool +} + +// newValueBinding builds a binding with an already resolved value. +func newValueBinding(v reflect.Value) *binding { + return &binding{val: v, done: true, singleton: true} +} + +// newFunctionBinding builds a binding with a function +// that will return a value of the target type. +// +// The function signature must be func(...) (T, error) or func(...) T +// where parameters are recursively resolved. +func newFunctionBinding(f reflect.Value, singleton bool) *binding { + return &binding{fn: f, singleton: singleton} +} + +// Get returns the pre-resolved value for the binding, +// or false if the binding is not resolved. +func (b *binding) Get() (v reflect.Value, ok bool) { + return b.val, b.done +} + +// Set sets the value of the binding to the given value, +// marking it as resolved. +// +// If the binding is not a singleton, this method does nothing. +func (b *binding) Set(v reflect.Value) { + if b.singleton { + b.val = v + b.done = true + } +} + // A map of type to function that returns a value of that type. // // The function should have the signature func(...) (T, error). Arguments are recursively resolved. -type bindings map[reflect.Type]any +type bindings map[reflect.Type]*binding func (b bindings) String() string { out := []string{} @@ -21,17 +70,18 @@ func (b bindings) String() string { func (b bindings) add(values ...any) bindings { for _, v := range values { - v := v - b[reflect.TypeOf(v)] = func() (any, error) { return v, nil } + val := reflect.ValueOf(v) + b[val.Type()] = newValueBinding(val) } return b } func (b bindings) addTo(impl, iface any) { - b[reflect.TypeOf(iface).Elem()] = func() (any, error) { return impl, nil } + val := reflect.ValueOf(impl) + b[reflect.TypeOf(iface).Elem()] = newValueBinding(val) } -func (b bindings) addProvider(provider any) error { +func (b bindings) addProvider(provider any, singleton bool) error { pv := reflect.ValueOf(provider) t := pv.Type() if t.Kind() != reflect.Func { @@ -47,7 +97,7 @@ func (b bindings) addProvider(provider any) error { } } rt := pv.Type().Out(0) - b[rt] = provider + b[rt] = newFunctionBinding(pv, singleton) return nil } @@ -148,19 +198,29 @@ func callAnyFunction(f reflect.Value, bindings bindings) (out []any, err error) t := f.Type() for i := 0; i < t.NumIn(); i++ { pt := t.In(i) - argf, ok := bindings[pt] + binding, ok := bindings[pt] if !ok { return nil, fmt.Errorf("couldn't find binding of type %s for parameter %d of %s(), use kong.Bind(%s)", pt, i, t, pt) } + + // Don't need to call the function if the value is already resolved. + if val, ok := binding.Get(); ok { + in = append(in, val) + continue + } + // Recursively resolve binding functions. - argv, err := callAnyFunction(reflect.ValueOf(argf), bindings) + argv, err := callAnyFunction(binding.fn, bindings) if err != nil { return nil, fmt.Errorf("%s: %w", pt, err) } if ferrv := reflect.ValueOf(argv[len(argv)-1]); ferrv.IsValid() && ferrv.Type().Implements(callbackReturnSignature) && !ferrv.IsNil() { return nil, ferrv.Interface().(error) //nolint:forcetypeassert } - in = append(in, reflect.ValueOf(argv[0])) + + val := reflect.ValueOf(argv[0]) + binding.Set(val) + in = append(in, val) } outv := f.Call(in) out = make([]any, len(outv)) diff --git a/context.go b/context.go index ebf4c31..7b1d482 100644 --- a/context.go +++ b/context.go @@ -120,10 +120,19 @@ func (c *Context) BindTo(impl, iface any) { // This is useful when the Run() function of different commands require different values that may // not all be initialisable from the main() function. // -// "provider" must be a function with the signature func(...) (T, error) or func(...) T, where -// ... will be recursively injected with bound values. +// "provider" must be a function with the signature func(...) (T, error) or func(...) T, +// where ... will be recursively injected with bound values. func (c *Context) BindToProvider(provider any) error { - return c.bindings.addProvider(provider) + return c.bindings.addProvider(provider, false /* singleton */) +} + +// BindSingletonProvider allows binding of provider functions. +// The provider will be called once and the result cached. +// +// "provider" must be a function with the signature func(...) (T, error) or func(...) T, +// where ... will be recursively injected with bound values. +func (c *Context) BindSingletonProvider(provider any) error { + return c.bindings.addProvider(provider, true /* singleton */) } // Value returns the value for a particular path element. @@ -792,7 +801,7 @@ func (c *Context) RunNode(node *Node, binds ...any) (err error) { methodt := t.Method(i) if strings.HasPrefix(methodt.Name, "Provide") { method := p.Method(i) - if err := methodBinds.addProvider(method.Interface()); err != nil { + if err := methodBinds.addProvider(method.Interface(), false /* singleton */); err != nil { return fmt.Errorf("%s.%s: %w", t.Name(), methodt.Name, err) } } diff --git a/options.go b/options.go index 6263202..d20b2fb 100644 --- a/options.go +++ b/options.go @@ -210,15 +210,33 @@ func BindTo(impl, iface any) Option { // BindToProvider binds an injected value to a provider function. // -// The provider function must have the signature: +// The provider function must have one of the following signatures: // -// func() (any, error) +// func(...) (T, error) +// func(...) T +// +// Where arguments to the function are injected by Kong. // // This is useful when the Run() function of different commands require different values that may // not all be initialisable from the main() function. func BindToProvider(provider any) Option { return OptionFunc(func(k *Kong) error { - return k.bindings.addProvider(provider) + return k.bindings.addProvider(provider, false /* singleton */) + }) +} + +// BindSingletonProvider binds an injected value to a provider function. +// The provider function must have the signature: +// +// func(...) (T, error) +// func(...) T +// +// Unlike [BindToProvider], the provider function will only be called +// at most once, and the result will be cached and reused +// across multiple recipients of the injected value. +func BindSingletonProvider(provider any) Option { + return OptionFunc(func(k *Kong) error { + return k.bindings.addProvider(provider, true /* singleton */) }) } diff --git a/options_test.go b/options_test.go index e549475..791cb64 100644 --- a/options_test.go +++ b/options_test.go @@ -119,6 +119,43 @@ func TestBindToProvider(t *testing.T) { assert.True(t, cli.Called) } +func TestBindSingletonProvider(t *testing.T) { + type ( + Connection struct{} + ClientA struct{ conn *Connection } + ClientB struct{ conn *Connection } + ) + + var numConnections int + newConnection := func() *Connection { + numConnections++ + return &Connection{} + } + + var cli struct{} + app, err := New(&cli, + BindSingletonProvider(newConnection), + BindToProvider(func(conn *Connection) *ClientA { + return &ClientA{conn: conn} + }), + BindToProvider(func(conn *Connection) *ClientB { + return &ClientB{conn: conn} + }), + ) + assert.NoError(t, err) + + ctx, err := app.Parse([]string{}) + assert.NoError(t, err) + + _, err = ctx.Call(func(a *ClientA, b *ClientB) { + assert.NotZero(t, a.conn) + assert.NotZero(t, b.conn) + + assert.Equal(t, 1, numConnections, "expected newConnection to be called only once") + }) + assert.NoError(t, err) +} + func TestFlagNamer(t *testing.T) { var cli struct { SomeFlag string