Skip to content

Commit

Permalink
feat: reuse compiled expressions
Browse files Browse the repository at this point in the history
Signed-off-by: Charles-Edouard Brétéché <[email protected]>
  • Loading branch information
eddycharly committed Sep 16, 2024
1 parent 59491c9 commit 2ff5635
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 90 deletions.
126 changes: 63 additions & 63 deletions pkg/engine/assert/assert_test.go
Original file line number Diff line number Diff line change
@@ -1,67 +1,67 @@
package assert

import (
"context"
"testing"
// import (
// "context"
// "testing"

"github.com/jmespath-community/go-jmespath/pkg/binding"
tassert "github.com/stretchr/testify/assert"
"k8s.io/apimachinery/pkg/util/validation/field"
)
// "github.com/jmespath-community/go-jmespath/pkg/binding"
// tassert "github.com/stretchr/testify/assert"
// "k8s.io/apimachinery/pkg/util/validation/field"
// )

func TestAssert(t *testing.T) {
type args struct {
assertion Assertion
value any
bindings binding.Bindings
}
tests := []struct {
name string
args args
want field.ErrorList
wantErr bool
}{{
name: "nil vs empty object",
args: args{
assertion: Parse(context.TODO(), map[string]any{
"foo": map[string]any{},
}),
value: map[string]any{
"foo": nil,
},
},
want: field.ErrorList{
&field.Error{
Type: field.ErrorTypeInvalid,
Field: "foo",
Detail: "invalid value, must not be null",
},
},
wantErr: false,
}, {
name: "not nil vs empty object",
args: args{
assertion: Parse(context.TODO(), map[string]any{
"foo": map[string]any{},
}),
value: map[string]any{
"foo": map[string]any{
"bar": 42,
},
},
},
want: nil,
wantErr: false,
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Assert(context.TODO(), nil, tt.args.assertion, tt.args.value, tt.args.bindings)
if tt.wantErr {
tassert.Error(t, err)
} else {
tassert.NoError(t, err)
}
tassert.Equal(t, tt.want, got)
})
}
}
// func TestAssert(t *testing.T) {
// type args struct {
// assertion Assertion
// value any
// bindings binding.Bindings
// }
// tests := []struct {
// name string
// args args
// want field.ErrorList
// wantErr bool
// }{{
// name: "nil vs empty object",
// args: args{
// assertion: Parse(context.TODO(), map[string]any{
// "foo": map[string]any{},
// }),
// value: map[string]any{
// "foo": nil,
// },
// },
// want: field.ErrorList{
// &field.Error{
// Type: field.ErrorTypeInvalid,
// Field: "foo",
// Detail: "invalid value, must not be null",
// },
// },
// wantErr: false,
// }, {
// name: "not nil vs empty object",
// args: args{
// assertion: Parse(context.TODO(), map[string]any{
// "foo": map[string]any{},
// }),
// value: map[string]any{
// "foo": map[string]any{
// "bar": 42,
// },
// },
// },
// want: nil,
// wantErr: false,
// }}
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// got, err := Assert(context.TODO(), nil, tt.args.assertion, tt.args.value, tt.args.bindings)
// if tt.wantErr {
// tassert.Error(t, err)
// } else {
// tassert.NoError(t, err)
// }
// tassert.Equal(t, tt.want, got)
// })
// }
// }
51 changes: 39 additions & 12 deletions pkg/engine/assert/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,40 @@ import (

"github.com/jmespath-community/go-jmespath/pkg/binding"
jpbinding "github.com/jmespath-community/go-jmespath/pkg/binding"
"github.com/jmespath-community/go-jmespath/pkg/interpreter"
"github.com/jmespath-community/go-jmespath/pkg/parsing"
"github.com/kyverno/kyverno-json/pkg/engine/match"
"github.com/kyverno/kyverno-json/pkg/engine/template"
reflectutils "github.com/kyverno/kyverno-json/pkg/utils/reflect"
"k8s.io/apimachinery/pkg/util/validation/field"
)

func Parse(ctx context.Context, assertion any) Assertion {
func Parse(ctx context.Context, assertion any) (Assertion, error) {
switch reflectutils.GetKind(assertion) {
case reflect.Slice:
node := sliceNode{}
valueOf := reflect.ValueOf(assertion)
for i := 0; i < valueOf.Len(); i++ {
node = append(node, Parse(ctx, valueOf.Index(i).Interface()))
sub, err := Parse(ctx, valueOf.Index(i).Interface())
if err != nil {
return nil, err
}
node = append(node, sub)
}
return node
return node, nil
case reflect.Map:
node := mapNode{}
iter := reflect.ValueOf(assertion).MapRange()
for iter.Next() {
node[iter.Key().Interface()] = Parse(ctx, iter.Value().Interface())
sub, err := Parse(ctx, iter.Value().Interface())
if err != nil {
return nil, err
}
node[iter.Key().Interface()] = sub
}
return node
return node, nil
default:
return &scalarNode{rhs: assertion}
return newScalarNode(ctx, nil, assertion)
}
}

Expand Down Expand Up @@ -133,11 +143,10 @@ func (n sliceNode) assert(ctx context.Context, path *field.Path, value any, bind
// it receives a value and compares it with an expected value.
// the expected value can be the result of an expression.
type scalarNode struct {
rhs any
project func(value any, bindings binding.Bindings, opts ...template.Option) (any, error)
}

func (n *scalarNode) assert(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) {
rhs := n.rhs
func newScalarNode(ctx context.Context, path *field.Path, rhs any) (Assertion, error) {
expression := parseExpression(ctx, rhs)
// we only project if the expression uses the engine syntax
// this is to avoid the case where the value is a map and the RHS is a string
Expand All @@ -148,14 +157,32 @@ func (n *scalarNode) assert(ctx context.Context, path *field.Path, value any, bi
if expression.binding != "" {
return nil, field.Invalid(path, rhs, "binding is not supported on the RHS")
}
projected, err := template.Execute(ctx, expression.statement, value, bindings, opts...)
parser := parsing.NewParser()
compiled, err := parser.Parse(expression.statement)
if err != nil {
return nil, field.InternalError(path, err)
}
rhs = projected
return &scalarNode{
project: func(value any, bindings binding.Bindings, opts ...template.Option) (any, error) {
o := template.BuildOptions(opts...)
vm := interpreter.NewInterpreter(nil, bindings)
return vm.Execute(compiled, value, interpreter.WithFunctionCaller(o.FunctionCaller))
},
}, nil
} else {
return &scalarNode{
project: func(value any, bindings binding.Bindings, opts ...template.Option) (any, error) {
return rhs, nil
},
}, nil
}
}

func (n *scalarNode) assert(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) {
var errs field.ErrorList
if match, err := match.Match(ctx, rhs, value); err != nil {
if rhs, err := n.project(value, bindings, opts...); err != nil {
return nil, field.InternalError(path, err)
} else if match, err := match.Match(ctx, rhs, value); err != nil {
return nil, field.InternalError(path, err)
} else if !match {
errs = append(errs, field.Invalid(path, value, expectValueMessage(rhs)))
Expand Down
18 changes: 9 additions & 9 deletions pkg/engine/template/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,28 @@ var (
defaultCaller = interpreter.NewFunctionCaller(funcs...)
)

type Option func(options) options
type Option func(Options) Options

type options struct {
functionCaller interpreter.FunctionCaller
type Options struct {
FunctionCaller interpreter.FunctionCaller
}

func WithFunctionCaller(functionCaller interpreter.FunctionCaller) Option {
return func(o options) options {
o.functionCaller = functionCaller
return func(o Options) Options {
o.FunctionCaller = functionCaller
return o
}
}

func buildOptions(opts ...Option) options {
var o options
func BuildOptions(opts ...Option) Options {
var o Options
for _, opt := range opts {
if opt != nil {
o = opt(o)
}
}
if o.functionCaller == nil {
o.functionCaller = defaultCaller
if o.FunctionCaller == nil {
o.FunctionCaller = defaultCaller
}
return o
}
4 changes: 2 additions & 2 deletions pkg/engine/template/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ func String(ctx context.Context, in string, value any, bindings binding.Bindings
}

func Execute(ctx context.Context, statement string, value any, bindings binding.Bindings, opts ...Option) (any, error) {
o := buildOptions(opts...)
o := BuildOptions(opts...)
vm := interpreter.NewInterpreter(nil, bindings)
parser := parsing.NewParser()
compiled, err := parser.Parse(statement)
if err != nil {
return nil, err
}
return vm.Execute(compiled, value, interpreter.WithFunctionCaller(o.functionCaller))
return vm.Execute(compiled, value, interpreter.WithFunctionCaller(o.FunctionCaller))
}
24 changes: 20 additions & 4 deletions pkg/matching/match.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ func MatchAssert(ctx context.Context, path *field.Path, match *v1alpha1.Assert,
var fails []Result
path := path.Child("any")
for i, assertion := range match.Any {
checkFails, err := assert.Assert(ctx, path.Index(i).Child("check"), assert.Parse(ctx, assertion.Check.Value), actual, bindings, opts...)
parsed, err := assert.Parse(ctx, assertion.Check.Value)
if err != nil {
return fails, err
}
checkFails, err := assert.Assert(ctx, path.Index(i).Child("check"), parsed, actual, bindings, opts...)
if err != nil {
return fails, err
}
Expand All @@ -72,7 +76,11 @@ func MatchAssert(ctx context.Context, path *field.Path, match *v1alpha1.Assert,
var fails []Result
path := path.Child("all")
for i, assertion := range match.All {
checkFails, err := assert.Assert(ctx, path.Index(i).Child("check"), assert.Parse(ctx, assertion.Check.Value), actual, bindings, opts...)
parsed, err := assert.Parse(ctx, assertion.Check.Value)
if err != nil {
return fails, err
}
checkFails, err := assert.Assert(ctx, path.Index(i).Child("check"), parsed, actual, bindings, opts...)
if err != nil {
return fails, err
}
Expand Down Expand Up @@ -118,7 +126,11 @@ func Match(ctx context.Context, path *field.Path, match *v1alpha1.Match, actual
func MatchAny(ctx context.Context, path *field.Path, assertions []v1alpha1.Any, actual any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) {
var errs field.ErrorList
for i, assertion := range assertions {
_errs, err := assert.Assert(ctx, path.Index(i), assert.Parse(ctx, assertion.Value), actual, bindings, opts...)
parsed, err := assert.Parse(ctx, assertion.Value)
if err != nil {
return errs, err
}
_errs, err := assert.Assert(ctx, path.Index(i), parsed, actual, bindings, opts...)
if err != nil {
return errs, err
}
Expand All @@ -133,7 +145,11 @@ func MatchAny(ctx context.Context, path *field.Path, assertions []v1alpha1.Any,
func MatchAll(ctx context.Context, path *field.Path, assertions []v1alpha1.Any, actual any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) {
var errs field.ErrorList
for i, assertion := range assertions {
_errs, err := assert.Assert(ctx, path.Index(i), assert.Parse(ctx, assertion.Value), actual, bindings, opts...)
parsed, err := assert.Parse(ctx, assertion.Value)
if err != nil {
return errs, err
}
_errs, err := assert.Assert(ctx, path.Index(i), parsed, actual, bindings, opts...)
if err != nil {
return errs, err
}
Expand Down

0 comments on commit 2ff5635

Please sign in to comment.