hooks: Recursively search embedded fields for methods (#494)

* hooks: Recursively search embedded fields for methods

Follow up to #493 and 840220c

Kong currently supports hooks on embedded fields of a parsed node,
but only at the first level of embedding:

```
type mainCmd struct {
    FooOptions
}

type FooOptions struct {
    BarOptions
}

func (f *FooOptions) BeforeApply() error {
    // this will be called
}

type BarOptions struct {
}

func (b *BarOptions) BeforeApply() error {
    // this will not be called
}
```

This change adds support for hooks to be defined
on embedded fields of embedded fields so that the above
example would work as expected.

Per #493, the definition of "embedded" field is adjusted to mean:

- Any anonymous (Go-embedded) field that is exported
- Any non-anonymous field that is tagged with `embed:""`

*Testing*:
Includes a test case for embedding an anonymous field in an `embed:""`
and an `embed:""` field in an anonymous field.

* Use recursion to build up the list of receivers

The 'receivers' parameter helps avoid constant memory allocation
as the backing storage for the slice is reused across recursive calls.
This commit is contained in:
Abhinav Gupta
2025-01-29 18:43:10 -08:00
committed by GitHub
parent 4e1757c0e8
commit 4be6ae6168
2 changed files with 60 additions and 23 deletions
+34 -16
View File
@@ -77,35 +77,53 @@ func getMethod(value reflect.Value, name string) reflect.Value {
return method return method
} }
// Get methods from the given value and any embedded fields. // getMethods gets all methods with the given name from the given value
// and any embedded fields.
//
// Returns a slice of bound methods that can be called directly.
func getMethods(value reflect.Value, name string) []reflect.Value { func getMethods(value reflect.Value, name string) []reflect.Value {
// Collect all possible receivers // Traverses embedded fields of the struct
receivers := []reflect.Value{value} // starting from the given value to collect all possible receivers
// for the given method name.
var traverse func(value reflect.Value, receivers []reflect.Value) []reflect.Value
traverse = func(value reflect.Value, receivers []reflect.Value) []reflect.Value {
// Always consider the current value for hooks.
receivers = append(receivers, value)
if value.Kind() == reflect.Ptr { if value.Kind() == reflect.Ptr {
value = value.Elem() value = value.Elem()
} }
// If the current value is a struct, also consider embedded fields.
// Two kinds of embedded fields are considered if they're exported:
//
// - standard Go embedded fields
// - fields tagged with `embed:""`
if value.Kind() == reflect.Struct { if value.Kind() == reflect.Struct {
t := value.Type() t := value.Type()
for i := 0; i < value.NumField(); i++ { for i := 0; i < value.NumField(); i++ {
field := value.Field(i) fieldValue := value.Field(i)
fieldType := t.Field(i) field := t.Field(i)
if !fieldType.IsExported() {
if !field.IsExported() {
continue continue
} }
// Hooks on exported embedded fields should be called. // Consider a field embedded if it's actually embedded
if fieldType.Anonymous { // or if it's tagged with `embed:""`.
receivers = append(receivers, field) _, isEmbedded := field.Tag.Lookup("embed")
continue isEmbedded = isEmbedded || field.Anonymous
if isEmbedded {
receivers = traverse(fieldValue, receivers)
}
}
} }
// Hooks on exported fields that are not exported, return receivers
// but are tagged with `embed:""` should be called.
if _, ok := fieldType.Tag.Lookup("embed"); ok {
receivers = append(receivers, field)
}
}
} }
receivers := traverse(value, nil /* receivers */)
// Search all receivers for methods // Search all receivers for methods
var methods []reflect.Value var methods []reflect.Value
for _, receiver := range receivers { for _, receiver := range receivers {
+19
View File
@@ -2405,6 +2405,8 @@ func TestProviderMethods(t *testing.T) {
} }
type EmbeddedCallback struct { type EmbeddedCallback struct {
Nested NestedCallback `embed:""`
Embedded bool Embedded bool
} }
@@ -2414,6 +2416,8 @@ func (e *EmbeddedCallback) AfterApply() error {
} }
type taggedEmbeddedCallback struct { type taggedEmbeddedCallback struct {
NestedCallback
Tagged bool Tagged bool
} }
@@ -2422,6 +2426,15 @@ func (e *taggedEmbeddedCallback) AfterApply() error {
return nil return nil
} }
type NestedCallback struct {
nested bool
}
func (n *NestedCallback) AfterApply() error {
n.nested = true
return nil
}
type EmbeddedRoot struct { type EmbeddedRoot struct {
EmbeddedCallback EmbeddedCallback
Tagged taggedEmbeddedCallback `embed:""` Tagged taggedEmbeddedCallback `embed:""`
@@ -2441,9 +2454,15 @@ func TestEmbeddedCallbacks(t *testing.T) {
expected := &EmbeddedRoot{ expected := &EmbeddedRoot{
EmbeddedCallback: EmbeddedCallback{ EmbeddedCallback: EmbeddedCallback{
Embedded: true, Embedded: true,
Nested: NestedCallback{
nested: true,
},
}, },
Tagged: taggedEmbeddedCallback{ Tagged: taggedEmbeddedCallback{
Tagged: true, Tagged: true,
NestedCallback: NestedCallback{
nested: true,
},
}, },
Root: true, Root: true,
} }