Skip to content

Commit 4857426

Browse files
committed
fix(patcher): inject ctx into nested custom funcs
WithContext patcher now looks up function types from the Functions table when the callee type is interface{}. This fixes context injection for custom functions nested as arguments inside method calls with unknown callee types (e.g., now2().After(date2())). Also improved the regression test to actually verify context is passed to both functions, which would have caught this bug. Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
1 parent d472286 commit 4857426

File tree

3 files changed

+41
-8
lines changed

3 files changed

+41
-8
lines changed

expr.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,12 @@ func EnableBuiltin(name string) Option {
195195

196196
// WithContext passes context to all functions calls with a context.Context argument.
197197
func WithContext(name string) Option {
198-
return Patch(patcher.WithContext{
199-
Name: name,
200-
})
198+
return func(c *conf.Config) {
199+
c.Visitors = append(c.Visitors, patcher.WithContext{
200+
Name: name,
201+
Functions: c.Functions,
202+
})
203+
}
201204
}
202205

203206
// Timezone sets default timezone for date() and now() builtin functions.

patcher/with_context.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@ import (
44
"reflect"
55

66
"github.com/expr-lang/expr/ast"
7+
"github.com/expr-lang/expr/conf"
78
)
89

910
// WithContext adds WithContext.Name argument to all functions calls with a context.Context argument.
1011
type WithContext struct {
11-
Name string
12+
Name string
13+
Functions conf.FunctionsTable // Optional: used to look up function types when callee type is unknown.
1214
}
1315

1416
// Visit adds WithContext.Name argument to all functions calls with a context.Context argument.
@@ -19,6 +21,16 @@ func (w WithContext) Visit(node *ast.Node) {
1921
if fn == nil {
2022
return
2123
}
24+
// If callee type is interface{} (unknown), try to look up the function type
25+
// from the Functions table. This handles cases where the call is nested
26+
// inside another call with an unknown callee type.
27+
if fn.Kind() == reflect.Interface && w.Functions != nil {
28+
if ident, ok := call.Callee.(*ast.IdentifierNode); ok {
29+
if f, ok := w.Functions[ident.Value]; ok {
30+
fn = f.Type()
31+
}
32+
}
33+
}
2234
if fn.Kind() != reflect.Func {
2335
return
2436
}

test/issues/823/issue_test.go

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package issue_test
22

33
import (
44
"context"
5-
"fmt"
65
"testing"
76
"time"
87

@@ -14,26 +13,45 @@ type env struct {
1413
Ctx context.Context `expr:"ctx"`
1514
}
1615

16+
// TestIssue823 verifies that WithContext injects context into nested custom
17+
// function calls. The bug was that date2() nested as an argument to After()
18+
// didn't receive the context because its callee type was unknown.
1719
func TestIssue823(t *testing.T) {
20+
now2Called := false
21+
date2Called := false
22+
1823
p, err := expr.Compile(
1924
"now2().After(date2())",
2025
expr.Env(env{}),
2126
expr.WithContext("ctx"),
2227
expr.Function(
2328
"now2",
24-
func(params ...any) (any, error) { return time.Now(), nil },
29+
func(params ...any) (any, error) {
30+
require.Len(t, params, 1, "now2 should receive context")
31+
_, ok := params[0].(context.Context)
32+
require.True(t, ok, "now2 first param should be context.Context")
33+
now2Called = true
34+
return time.Now(), nil
35+
},
2536
new(func(context.Context) time.Time),
2637
),
2738
expr.Function(
2839
"date2",
29-
func(params ...any) (any, error) { return time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), nil },
40+
func(params ...any) (any, error) {
41+
require.Len(t, params, 1, "date2 should receive context")
42+
_, ok := params[0].(context.Context)
43+
require.True(t, ok, "date2 first param should be context.Context")
44+
date2Called = true
45+
return time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), nil
46+
},
3047
new(func(context.Context) time.Time),
3148
),
3249
)
33-
fmt.Printf("Compile result err: %v\n", err)
3450
require.NoError(t, err)
3551

3652
r, err := expr.Run(p, &env{Ctx: context.Background()})
3753
require.NoError(t, err)
3854
require.True(t, r.(bool))
55+
require.True(t, now2Called, "now2 should have been called")
56+
require.True(t, date2Called, "date2 should have been called")
3957
}

0 commit comments

Comments
 (0)