From 4be6ae616831001dda40c5c34dc1a364eb8b2c68 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Wed, 29 Jan 2025 18:43:10 -0800 Subject: [PATCH] hooks: Recursively search embedded fields for methods (#494) * hooks: Recursively search embedded fields for methods Follow up to #493 and 840220c Kong currently supports hooks on embedded fields of a parsed node, but only at the first level of embedding: ``` type mainCmd struct { FooOptions } type FooOptions struct { BarOptions } func (f *FooOptions) BeforeApply() error { // this will be called } type BarOptions struct { } func (b *BarOptions) BeforeApply() error { // this will not be called } ``` This change adds support for hooks to be defined on embedded fields of embedded fields so that the above example would work as expected. Per #493, the definition of "embedded" field is adjusted to mean: - Any anonymous (Go-embedded) field that is exported - Any non-anonymous field that is tagged with `embed:""` *Testing*: Includes a test case for embedding an anonymous field in an `embed:""` and an `embed:""` field in an anonymous field. * Use recursion to build up the list of receivers The 'receivers' parameter helps avoid constant memory allocation as the backing storage for the slice is reused across recursive calls. --- callbacks.go | 64 +++++++++++++++++++++++++++++++++------------------- kong_test.go | 19 ++++++++++++++++ 2 files changed, 60 insertions(+), 23 deletions(-) diff --git a/callbacks.go b/callbacks.go index b4fe3ca..4644c54 100644 --- a/callbacks.go +++ b/callbacks.go @@ -77,35 +77,53 @@ func getMethod(value reflect.Value, name string) reflect.Value { return method } -// Get methods from the given value and any embedded fields. +// getMethods gets all methods with the given name from the given value +// and any embedded fields. +// +// Returns a slice of bound methods that can be called directly. func getMethods(value reflect.Value, name string) []reflect.Value { - // Collect all possible receivers - receivers := []reflect.Value{value} - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - if value.Kind() == reflect.Struct { - t := value.Type() - for i := 0; i < value.NumField(); i++ { - field := value.Field(i) - fieldType := t.Field(i) - if !fieldType.IsExported() { - continue - } + // Traverses embedded fields of the struct + // starting from the given value to collect all possible receivers + // for the given method name. + var traverse func(value reflect.Value, receivers []reflect.Value) []reflect.Value + traverse = func(value reflect.Value, receivers []reflect.Value) []reflect.Value { + // Always consider the current value for hooks. + receivers = append(receivers, value) - // Hooks on exported embedded fields should be called. - if fieldType.Anonymous { - receivers = append(receivers, field) - continue - } + if value.Kind() == reflect.Ptr { + value = value.Elem() + } - // Hooks on exported fields that are not exported, - // but are tagged with `embed:""` should be called. - if _, ok := fieldType.Tag.Lookup("embed"); ok { - receivers = append(receivers, field) + // If the current value is a struct, also consider embedded fields. + // Two kinds of embedded fields are considered if they're exported: + // + // - standard Go embedded fields + // - fields tagged with `embed:""` + if value.Kind() == reflect.Struct { + t := value.Type() + for i := 0; i < value.NumField(); i++ { + fieldValue := value.Field(i) + field := t.Field(i) + + if !field.IsExported() { + continue + } + + // Consider a field embedded if it's actually embedded + // or if it's tagged with `embed:""`. + _, isEmbedded := field.Tag.Lookup("embed") + isEmbedded = isEmbedded || field.Anonymous + if isEmbedded { + receivers = traverse(fieldValue, receivers) + } } } + + return receivers } + + receivers := traverse(value, nil /* receivers */) + // Search all receivers for methods var methods []reflect.Value for _, receiver := range receivers { diff --git a/kong_test.go b/kong_test.go index 2ceb1b1..6b5f5d6 100644 --- a/kong_test.go +++ b/kong_test.go @@ -2405,6 +2405,8 @@ func TestProviderMethods(t *testing.T) { } type EmbeddedCallback struct { + Nested NestedCallback `embed:""` + Embedded bool } @@ -2414,6 +2416,8 @@ func (e *EmbeddedCallback) AfterApply() error { } type taggedEmbeddedCallback struct { + NestedCallback + Tagged bool } @@ -2422,6 +2426,15 @@ func (e *taggedEmbeddedCallback) AfterApply() error { return nil } +type NestedCallback struct { + nested bool +} + +func (n *NestedCallback) AfterApply() error { + n.nested = true + return nil +} + type EmbeddedRoot struct { EmbeddedCallback Tagged taggedEmbeddedCallback `embed:""` @@ -2441,9 +2454,15 @@ func TestEmbeddedCallbacks(t *testing.T) { expected := &EmbeddedRoot{ EmbeddedCallback: EmbeddedCallback{ Embedded: true, + Nested: NestedCallback{ + nested: true, + }, }, Tagged: taggedEmbeddedCallback{ Tagged: true, + NestedCallback: NestedCallback{ + nested: true, + }, }, Root: true, }