Skip to content

Commit

Permalink
Merge branch 'master' into sum-predicate
Browse files Browse the repository at this point in the history
  • Loading branch information
antonmedv committed Apr 13, 2024
2 parents 0ef7dc5 + d66ffcd commit 0cf2376
Show file tree
Hide file tree
Showing 22 changed files with 941 additions and 93 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ func main() {
* [Visually.io](https://visually.io) employs Expr as a business rule engine for its personalization targeting algorithm.
* [Akvorado](https://github.com/akvorado/akvorado) utilizes Expr to classify exporters and interfaces in network flows.
* [keda.sh](https://keda.sh) uses Expr to allow customization of its Kubernetes-based event-driven autoscaling.
* [Span Digital](https://spandigital.com/) uses Expr in it's Knowledge Management products.

[Add your company too](https://github.com/expr-lang/expr/edit/master/README.md)

Expand Down
8 changes: 7 additions & 1 deletion ast/print.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,5 +202,11 @@ func (n *MapNode) String() string {
}

func (n *PairNode) String() string {
return fmt.Sprintf("%s: %s", n.Key.String(), n.Value.String())
if str, ok := n.Key.(*StringNode); ok {
if utils.IsValidIdentifier(str.Value) {
return fmt.Sprintf("%s: %s", str.Value, n.Value.String())
}
return fmt.Sprintf("%q: %s", str.String(), n.Value.String())
}
return fmt.Sprintf("(%s): %s", n.Key.String(), n.Value.String())
}
5 changes: 3 additions & 2 deletions ast/print_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ func TestPrint(t *testing.T) {
{`func(a)`, `func(a)`},
{`func(a, b)`, `func(a, b)`},
{`{}`, `{}`},
{`{a: b}`, `{"a": b}`},
{`{a: b, c: d}`, `{"a": b, "c": d}`},
{`{a: b}`, `{a: b}`},
{`{a: b, c: d}`, `{a: b, c: d}`},
{`[]`, `[]`},
{`[a]`, `[a]`},
{`[a, b]`, `[a, b]`},
Expand All @@ -71,6 +71,7 @@ func TestPrint(t *testing.T) {
{`a[1:]`, `a[1:]`},
{`a[:]`, `a[:]`},
{`(nil ?? 1) > 0`, `(nil ?? 1) > 0`},
{`{("a" + "b"): 42}`, `{("a" + "b"): 42}`},
}

for _, tt := range tests {
Expand Down
14 changes: 14 additions & 0 deletions builtin/builtin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -608,3 +608,17 @@ func TestBuiltin_bitOpsFunc(t *testing.T) {
})
}
}

type customInt int

func Test_int_unwraps_underlying_value(t *testing.T) {
env := map[string]any{
"customInt": customInt(42),
}
program, err := expr.Compile(`int(customInt) == 42`, expr.Env(env))
require.NoError(t, err)

out, err := expr.Run(program, env)
require.NoError(t, err)
assert.Equal(t, true, out)
}
4 changes: 4 additions & 0 deletions builtin/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ func Int(x any) any {
}
return i
default:
val := reflect.ValueOf(x)
if val.CanConvert(integerType) {
return val.Convert(integerType).Interface()
}
panic(fmt.Sprintf("invalid operation: int(%T)", x))
}
}
Expand Down
2 changes: 1 addition & 1 deletion checker/checker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ func TestCheck_TaggedFieldName(t *testing.T) {
tree, err := parser.Parse(`foo.bar`)
require.NoError(t, err)

config := &conf.Config{}
config := conf.CreateNew()
expr.Env(struct {
x struct {
y bool `expr:"bar"`
Expand Down
48 changes: 24 additions & 24 deletions compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,34 +395,12 @@ func (c *compiler) UnaryNode(node *ast.UnaryNode) {
}

func (c *compiler) BinaryNode(node *ast.BinaryNode) {
l := kind(node.Left)
r := kind(node.Right)

leftIsSimple := isSimpleType(node.Left)
rightIsSimple := isSimpleType(node.Right)
leftAndRightAreSimple := leftIsSimple && rightIsSimple

switch node.Operator {
case "==":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)

if l == r && l == reflect.Int && leftAndRightAreSimple {
c.emit(OpEqualInt)
} else if l == r && l == reflect.String && leftAndRightAreSimple {
c.emit(OpEqualString)
} else {
c.emit(OpEqual)
}
c.equalBinaryNode(node)

case "!=":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpEqual)
c.equalBinaryNode(node)
c.emit(OpNot)

case "or", "||":
Expand Down Expand Up @@ -580,6 +558,28 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) {
}
}

func (c *compiler) equalBinaryNode(node *ast.BinaryNode) {
l := kind(node.Left)
r := kind(node.Right)

leftIsSimple := isSimpleType(node.Left)
rightIsSimple := isSimpleType(node.Right)
leftAndRightAreSimple := leftIsSimple && rightIsSimple

c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)

if l == r && l == reflect.Int && leftAndRightAreSimple {
c.emit(OpEqualInt)
} else if l == r && l == reflect.String && leftAndRightAreSimple {
c.emit(OpEqualString)
} else {
c.emit(OpEqual)
}
}

func isSimpleType(node ast.Node) bool {
if node == nil {
return false
Expand Down
33 changes: 33 additions & 0 deletions compiler/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,39 @@ func TestCompile_optimizes_jumps(t *testing.T) {
{vm.OpFetch, 0},
},
},
{
`-1 not in [1, 2, 5]`,
[]op{
{vm.OpPush, 0},
{vm.OpPush, 1},
{vm.OpIn, 0},
{vm.OpNot, 0},
},
},
{
`1 + 8 not in [1, 2, 5]`,
[]op{
{vm.OpPush, 0},
{vm.OpPush, 1},
{vm.OpIn, 0},
{vm.OpNot, 0},
},
},
{
`true ? false : 8 not in [1, 2, 5]`,
[]op{
{vm.OpTrue, 0},
{vm.OpJumpIfFalse, 3},
{vm.OpPop, 0},
{vm.OpFalse, 0},
{vm.OpJump, 5},
{vm.OpPop, 0},
{vm.OpPush, 0},
{vm.OpPush, 1},
{vm.OpIn, 0},
{vm.OpNot, 0},
},
},
}

for _, test := range tests {
Expand Down
6 changes: 5 additions & 1 deletion conf/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type Config struct {
func CreateNew() *Config {
c := &Config{
Optimize: true,
Types: make(TypesTable),
ConstFns: make(map[string]reflect.Value),
Functions: make(map[string]*builtin.Function),
Builtins: make(map[string]*builtin.Function),
Expand Down Expand Up @@ -62,7 +63,10 @@ func (c *Config) WithEnv(env any) {
}

c.Env = env
c.Types = CreateTypesTable(env)
types := CreateTypesTable(env)
for name, t := range types {
c.Types[name] = t
}
c.MapEnv = mapEnv
c.DefaultType = mapValueType
c.Strict = true
Expand Down
Loading

0 comments on commit 0cf2376

Please sign in to comment.