diff --git a/pkg/engine/assert/parse.go b/pkg/engine/assert/parse.go index 338d84b3..736896f0 100644 --- a/pkg/engine/assert/parse.go +++ b/pkg/engine/assert/parse.go @@ -7,7 +7,6 @@ import ( "github.com/jmespath-community/go-jmespath/pkg/binding" jpbinding "github.com/jmespath-community/go-jmespath/pkg/binding" - "github.com/jmespath-community/go-jmespath/pkg/interpreter" "github.com/jmespath-community/go-jmespath/pkg/parsing" "github.com/kyverno/kyverno-json/pkg/engine/match" "github.com/kyverno/kyverno-json/pkg/engine/template" @@ -15,13 +14,13 @@ import ( "k8s.io/apimachinery/pkg/util/validation/field" ) -func Parse(ctx context.Context, assertion any) (Assertion, error) { +func Parse(ctx context.Context, path *field.Path, assertion any) (Assertion, error) { switch reflectutils.GetKind(assertion) { case reflect.Slice: node := sliceNode{} valueOf := reflect.ValueOf(assertion) for i := 0; i < valueOf.Len(); i++ { - sub, err := Parse(ctx, valueOf.Index(i).Interface()) + sub, err := Parse(ctx, path.Index(i), valueOf.Index(i).Interface()) if err != nil { return nil, err } @@ -32,16 +31,16 @@ func Parse(ctx context.Context, assertion any) (Assertion, error) { node := mapNode{} iter := reflect.ValueOf(assertion).MapRange() for iter.Next() { - sub, err := Parse(ctx, iter.Value().Interface()) + key := iter.Key().Interface() + sub, err := Parse(ctx, path.Child(fmt.Sprint(key)), iter.Value().Interface()) if err != nil { return nil, err } - node[iter.Key().Interface()] = sub + node[key] = sub } return node, nil default: - // TODO: propagate path - return newScalarNode(ctx, nil, assertion) + return newScalarNode(ctx, path, assertion) } } @@ -165,9 +164,7 @@ func newScalarNode(ctx context.Context, path *field.Path, rhs any) (Assertion, e } return &scalarNode{ project: func(value any, bindings binding.Bindings, opts ...template.Option) (any, error) { - o := template.BuildOptions(opts...) - vm := interpreter.NewInterpreter(nil, bindings) - return vm.Execute(compiled, value, interpreter.WithFunctionCaller(o.FunctionCaller)) + return template.ExecuteAST(ctx, compiled, value, bindings, opts...) }, }, nil } else { diff --git a/pkg/engine/template/options.go b/pkg/engine/template/options.go index d2ff48e2..8062622e 100644 --- a/pkg/engine/template/options.go +++ b/pkg/engine/template/options.go @@ -11,28 +11,28 @@ var ( defaultCaller = interpreter.NewFunctionCaller(funcs...) ) -type Option func(Options) Options +type Option func(options) options -type Options struct { - FunctionCaller interpreter.FunctionCaller +type options struct { + functionCaller interpreter.FunctionCaller } func WithFunctionCaller(functionCaller interpreter.FunctionCaller) Option { - return func(o Options) Options { - o.FunctionCaller = functionCaller + return func(o options) options { + o.functionCaller = functionCaller return o } } -func BuildOptions(opts ...Option) Options { - var o Options +func buildOptions(opts ...Option) options { + var o options for _, opt := range opts { if opt != nil { o = opt(o) } } - if o.FunctionCaller == nil { - o.FunctionCaller = defaultCaller + if o.functionCaller == nil { + o.functionCaller = defaultCaller } return o } diff --git a/pkg/engine/template/template.go b/pkg/engine/template/template.go index 6a4a28a0..df9b7962 100644 --- a/pkg/engine/template/template.go +++ b/pkg/engine/template/template.go @@ -32,12 +32,16 @@ func String(ctx context.Context, in string, value any, bindings binding.Bindings } func Execute(ctx context.Context, statement string, value any, bindings binding.Bindings, opts ...Option) (any, error) { - o := BuildOptions(opts...) - vm := interpreter.NewInterpreter(nil, bindings) parser := parsing.NewParser() compiled, err := parser.Parse(statement) if err != nil { return nil, err } - return vm.Execute(compiled, value, interpreter.WithFunctionCaller(o.FunctionCaller)) + return ExecuteAST(ctx, compiled, value, bindings, opts...) +} + +func ExecuteAST(ctx context.Context, ast parsing.ASTNode, value any, bindings binding.Bindings, opts ...Option) (any, error) { + o := buildOptions(opts...) + vm := interpreter.NewInterpreter(nil, bindings) + return vm.Execute(ast, value, interpreter.WithFunctionCaller(o.functionCaller)) } diff --git a/pkg/matching/match.go b/pkg/matching/match.go index 71f719d5..a4a25362 100644 --- a/pkg/matching/match.go +++ b/pkg/matching/match.go @@ -48,7 +48,7 @@ func MatchAssert(ctx context.Context, path *field.Path, match *v1alpha1.Assert, var fails []Result path := path.Child("any") for i, assertion := range match.Any { - parsed, err := assert.Parse(ctx, assertion.Check.Value) + parsed, err := assert.Parse(ctx, path.Index(i).Child("check"), assertion.Check.Value) if err != nil { return fails, err } @@ -76,7 +76,7 @@ func MatchAssert(ctx context.Context, path *field.Path, match *v1alpha1.Assert, var fails []Result path := path.Child("all") for i, assertion := range match.All { - parsed, err := assert.Parse(ctx, assertion.Check.Value) + parsed, err := assert.Parse(ctx, path.Index(i).Child("check"), assertion.Check.Value) if err != nil { return fails, err } @@ -126,7 +126,7 @@ func Match(ctx context.Context, path *field.Path, match *v1alpha1.Match, actual func MatchAny(ctx context.Context, path *field.Path, assertions []v1alpha1.Any, actual any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) { var errs field.ErrorList for i, assertion := range assertions { - parsed, err := assert.Parse(ctx, assertion.Value) + parsed, err := assert.Parse(ctx, path.Index(i), assertion.Value) if err != nil { return errs, err } @@ -145,7 +145,7 @@ func MatchAny(ctx context.Context, path *field.Path, assertions []v1alpha1.Any, func MatchAll(ctx context.Context, path *field.Path, assertions []v1alpha1.Any, actual any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) { var errs field.ErrorList for i, assertion := range assertions { - parsed, err := assert.Parse(ctx, assertion.Value) + parsed, err := assert.Parse(ctx, path.Index(i), assertion.Value) if err != nil { return errs, err }