From 54d5afad9fcf165df254ca214b51d532076c4b82 Mon Sep 17 00:00:00 2001
From: Kirill Semenchenko <k.semenchenko@maximatelecom.ru>
Date: Wed, 27 Jun 2018 18:41:35 +0300
Subject: [PATCH] Add IP and IPNET support

---
 TokenKind.go              |  6 +++
 dummies_test.go           | 48 +++++++++++++++++++-
 evaluationFailure_test.go | 93 +++++++++++++++++++++++++++++++++++++++
 evaluationStage.go        | 25 ++++++++++-
 evaluation_test.go        | 83 +++++++++++++++++++++++++++++++++-
 lexerState.go             | 26 +++++++++++
 parsing.go                | 41 +++++++++++++----
 parsingFailure_test.go    |  3 +-
 parsing_test.go           | 33 +++++++++++---
 stagePlanner.go           |  4 ++
 tokenKind_test.go         |  2 +
 11 files changed, 344 insertions(+), 20 deletions(-)

diff --git a/TokenKind.go b/TokenKind.go
index 7c9516d..c143c48 100644
--- a/TokenKind.go
+++ b/TokenKind.go
@@ -10,6 +10,8 @@ const (
 
 	PREFIX
 	NUMERIC
+	IP
+	IPNET
 	BOOLEAN
 	STRING
 	PATTERN
@@ -41,6 +43,10 @@ func (kind TokenKind) String() string {
 		return "PREFIX"
 	case NUMERIC:
 		return "NUMERIC"
+	case IP:
+		return "IP"
+	case IPNET:
+		return "IPNET"
 	case BOOLEAN:
 		return "BOOLEAN"
 	case STRING:
diff --git a/dummies_test.go b/dummies_test.go
index e3a1a2e..bd5e0fd 100644
--- a/dummies_test.go
+++ b/dummies_test.go
@@ -3,6 +3,7 @@ package govaluate
 import (
 	"errors"
 	"fmt"
+	"net"
 )
 
 /*
@@ -14,6 +15,10 @@ type dummyParameter struct {
 	BoolFalse bool
 	Nil       interface{}
 	Nested    dummyNestedParameter
+	IP1       net.IP
+	IP2       net.IP
+	CIDR1     net.IPNet
+	CIDR2     net.IPNet
 }
 
 func (this dummyParameter) Func() string {
@@ -33,9 +38,9 @@ func (this dummyParameter) FuncArgStr(arg1 string) string {
 }
 
 func (this dummyParameter) TestArgs(str string, ui uint, ui8 uint8, ui16 uint16, ui32 uint32, ui64 uint64, i int, i8 int8, i16 int16, i32 int32, i64 int64, f32 float32, f64 float64, b bool) string {
-	
+
 	var sum float64
-	
+
 	sum = float64(ui) + float64(ui8) + float64(ui16) + float64(ui32) + float64(ui64)
 	sum += float64(i) + float64(i8) + float64(i16) + float64(i32) + float64(i64)
 	sum += float64(f32)
@@ -67,6 +72,10 @@ var dummyParameterInstance = dummyParameter{
 	Nested: dummyNestedParameter{
 		Funk: "funkalicious",
 	},
+	IP1:   net.ParseIP("127.0.0.1"),
+	IP2:   net.ParseIP("127.0.0.3"),
+	CIDR1: mustParseCIDR("127.0.0.4/22"),
+	CIDR2: mustParseCIDR("27.0.0.0/12"),
 }
 
 var fooParameter = EvaluationParameter{
@@ -74,6 +83,16 @@ var fooParameter = EvaluationParameter{
 	Value: dummyParameterInstance,
 }
 
+var fooParameterEmptyIP = EvaluationParameter{
+	Name: "foo",
+	Value: dummyParameter{
+		IP1:   nil,
+		IP2:   nil,
+		CIDR1: mustParseCIDR("127.0.0.4/22"),
+		CIDR2: mustParseCIDR("27.0.0.0/12"),
+	},
+}
+
 var fooPtrParameter = EvaluationParameter{
 	Name:  "fooptr",
 	Value: &dummyParameterInstance,
@@ -83,3 +102,28 @@ var fooFailureParameters = map[string]interface{}{
 	"foo":    fooParameter.Value,
 	"fooptr": &fooPtrParameter.Value,
 }
+
+func mustParseCIDR(cidr string) net.IPNet {
+	_, ipNet, err := net.ParseCIDR(cidr)
+	if err != nil {
+		panic(err)
+	}
+	return *ipNet
+
+}
+
+var CIDRTestFunction = map[string]ExpressionFunction{
+	"InNetwork": func(args ...interface{}) (interface{}, error) {
+		ip, ok1 := args[0].(net.IP)
+		ipNet, ok2 := args[1].(net.IPNet)
+
+		if !ok1 {
+			return nil, fmt.Errorf("variable %s not ip", args[0])
+		}
+		if !ok2 {
+			return nil, fmt.Errorf("variable %s not IPnet its %T", args[1], args[1])
+		}
+
+		return ipNet.Contains(ip), nil
+	},
+}
diff --git a/evaluationFailure_test.go b/evaluationFailure_test.go
index e04f24f..cd813ed 100644
--- a/evaluationFailure_test.go
+++ b/evaluationFailure_test.go
@@ -6,6 +6,7 @@ package govaluate
 import (
 	"errors"
 	"fmt"
+	"net"
 	"strings"
 	"testing"
 )
@@ -43,6 +44,8 @@ var EVALUATION_FAILURE_PARAMETERS = map[string]interface{}{
 	"number": 1,
 	"string": "foo",
 	"bool":   true,
+	"ip":     net.ParseIP("127.0.0.1"),
+	"ipnet":  mustParseCIDR("127.0.0.1/12"),
 }
 
 func TestComplexParameter(test *testing.T) {
@@ -185,6 +188,42 @@ func TestModifierTyping(test *testing.T) {
 			Input:    "number >> bool",
 			Expected: INVALID_MODIFIER_TYPES,
 		},
+		EvaluationFailureTest{
+
+			Name:     "BITWISE_RSHIFT bool to IP",
+			Input:    "bool >> ip",
+			Expected: INVALID_MODIFIER_TYPES,
+		},
+		EvaluationFailureTest{
+
+			Name:     "BITWISE_OR string to ip",
+			Input:    "string | ip",
+			Expected: INVALID_MODIFIER_TYPES,
+		},
+		EvaluationFailureTest{
+
+			Name:     "BITWISE_RSHIFT bool to ipnet",
+			Input:    "bool >> ipnet",
+			Expected: INVALID_MODIFIER_TYPES,
+		},
+		EvaluationFailureTest{
+
+			Name:     "BITWISE_OR string to ipnet",
+			Input:    "string | ipnet",
+			Expected: INVALID_MODIFIER_TYPES,
+		},
+		EvaluationFailureTest{
+
+			Name:     "BITWISE_RSHIFT number to ipnet",
+			Input:    "number >> ipnet",
+			Expected: INVALID_MODIFIER_TYPES,
+		},
+		EvaluationFailureTest{
+
+			Name:     "BITWISE_OR number to ipnet",
+			Input:    "number | ipnet",
+			Expected: INVALID_MODIFIER_TYPES,
+		},
 	}
 
 	runEvaluationFailureTests(evaluationTests, test)
@@ -241,6 +280,42 @@ func TestLogicalOperatorTyping(test *testing.T) {
 			Input:    "string || bool",
 			Expected: INVALID_LOGICALOP_TYPES,
 		},
+		EvaluationFailureTest{
+
+			Name:     "AND bool to IP",
+			Input:    "bool && ip",
+			Expected: INVALID_LOGICALOP_TYPES,
+		},
+		EvaluationFailureTest{
+
+			Name:     "OR string to ip",
+			Input:    "string || ip",
+			Expected: INVALID_LOGICALOP_TYPES,
+		},
+		EvaluationFailureTest{
+
+			Name:     "AND bool to ipnet",
+			Input:    "bool && ipnet",
+			Expected: INVALID_LOGICALOP_TYPES,
+		},
+		EvaluationFailureTest{
+
+			Name:     "OR string to ipnet",
+			Input:    "string || ipnet",
+			Expected: INVALID_LOGICALOP_TYPES,
+		},
+		EvaluationFailureTest{
+
+			Name:     "AND number to ipnet",
+			Input:    "number && ipnet",
+			Expected: INVALID_LOGICALOP_TYPES,
+		},
+		EvaluationFailureTest{
+
+			Name:     "OR number to ipnet",
+			Input:    "number || ipnet",
+			Expected: INVALID_LOGICALOP_TYPES,
+		},
 	}
 
 	runEvaluationFailureTests(evaluationTests, test)
@@ -368,6 +443,18 @@ func TestComparatorTyping(test *testing.T) {
 			Input:    "1 in true",
 			Expected: INVALID_COMPARATOR_TYPES,
 		},
+		EvaluationFailureTest{
+
+			Name:     "NREQ bool to ipnet",
+			Input:    "bool !~ ipnet",
+			Expected: INVALID_COMPARATOR_TYPES,
+		},
+		EvaluationFailureTest{
+
+			Name:     "IN string to ipnet",
+			Input:    "string in ipnet",
+			Expected: INVALID_COMPARATOR_TYPES,
+		},
 	}
 
 	runEvaluationFailureTests(evaluationTests, test)
@@ -388,6 +475,12 @@ func TestTernaryTyping(test *testing.T) {
 			Input:    "'foo' ? true",
 			Expected: INVALID_TERNARY_TYPES,
 		},
+		EvaluationFailureTest{
+
+			Name:     "Ternary with ip",
+			Input:    "'foo' ? ip",
+			Expected: INVALID_TERNARY_TYPES,
+		},
 	}
 
 	runEvaluationFailureTests(evaluationTests, test)
diff --git a/evaluationStage.go b/evaluationStage.go
index 11ea587..6ec38a1 100644
--- a/evaluationStage.go
+++ b/evaluationStage.go
@@ -1,9 +1,11 @@
 package govaluate
 
 import (
+	"bytes"
 	"errors"
 	"fmt"
 	"math"
+	"net"
 	"reflect"
 	"regexp"
 	"strings"
@@ -114,24 +116,36 @@ func gteStage(left interface{}, right interface{}, parameters Parameters) (inter
 	if isString(left) && isString(right) {
 		return boolIface(left.(string) >= right.(string)), nil
 	}
+	if isIp(left) && isIp(right) {
+		return boolIface(bytes.Compare(left.(net.IP).To4(), right.(net.IP).To4()) >= 0), nil
+	}
 	return boolIface(left.(float64) >= right.(float64)), nil
 }
 func gtStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) {
 	if isString(left) && isString(right) {
 		return boolIface(left.(string) > right.(string)), nil
 	}
+	if isIp(left) && isIp(right) {
+		return boolIface(bytes.Compare(left.(net.IP).To4(), right.(net.IP).To4()) > 0), nil
+	}
 	return boolIface(left.(float64) > right.(float64)), nil
 }
 func lteStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) {
 	if isString(left) && isString(right) {
 		return boolIface(left.(string) <= right.(string)), nil
 	}
+	if isIp(left) && isIp(right) {
+		return boolIface(bytes.Compare(left.(net.IP).To4(), right.(net.IP).To4()) <= 0), nil
+	}
 	return boolIface(left.(float64) <= right.(float64)), nil
 }
 func ltStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) {
 	if isString(left) && isString(right) {
 		return boolIface(left.(string) < right.(string)), nil
 	}
+	if isIp(left) && isIp(right) {
+		return boolIface(bytes.Compare(left.(net.IP).To4(), right.(net.IP).To4()) == 0), nil
+	}
 	return boolIface(left.(float64) < right.(float64)), nil
 }
 func equalStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) {
@@ -421,7 +435,7 @@ func separatorStage(left interface{}, right interface{}, parameters Parameters)
 func inStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) {
 
 	for _, value := range right.([]interface{}) {
-		if left == value {
+		if reflect.DeepEqual(left, value) {
 			return true, nil
 		}
 	}
@@ -439,6 +453,11 @@ func isString(value interface{}) bool {
 	return false
 }
 
+func isIp(value interface{}) bool {
+	_, ok := value.(net.IP)
+	return ok
+}
+
 func isRegexOrString(value interface{}) bool {
 
 	switch value.(type) {
@@ -493,6 +512,10 @@ func comparatorTypeCheck(left interface{}, right interface{}) bool {
 	if isString(left) && isString(right) {
 		return true
 	}
+	if isIp(left) && isIp(right) {
+		return true
+	}
+
 	return false
 }
 
diff --git a/evaluation_test.go b/evaluation_test.go
index a2b65e8..57af16b 100644
--- a/evaluation_test.go
+++ b/evaluation_test.go
@@ -710,7 +710,7 @@ func TestNoParameterEvaluation(test *testing.T) {
 			Expected: true,
 		},
 		EvaluationTest{
-			
+
 			Name:  "Ternary/Java EL ambiguity",
 			Input: "false ? foo:length()",
 			Functions: map[string]ExpressionFunction{
@@ -1419,6 +1419,87 @@ func TestParameterizedEvaluation(test *testing.T) {
 			Parameters: []EvaluationParameter{fooParameter},
 			Expected:   false,
 		},
+		EvaluationTest{
+			Name:       "IP equal",
+			Input:      "foo.IP1 == foo.IP1",
+			Parameters: []EvaluationParameter{fooParameter},
+			Expected:   true,
+		},
+		EvaluationTest{
+			Name:       "IP not equal",
+			Input:      "foo.IP1 != foo.IP1",
+			Parameters: []EvaluationParameter{fooParameter},
+			Expected:   false,
+		},
+		EvaluationTest{
+			Name:       "IP greater than equal",
+			Input:      "foo.IP2 >= foo.IP1",
+			Parameters: []EvaluationParameter{fooParameter},
+			Expected:   true,
+		},
+		EvaluationTest{
+			Name:       "IP greater than equal false result",
+			Input:      "foo.IP1 >= foo.IP2",
+			Parameters: []EvaluationParameter{fooParameter},
+			Expected:   false,
+		},
+		EvaluationTest{
+			Name:       "IP in",
+			Input:      "foo.IP1 in (foo.IP2, 127.0.0.2, foo.IP1)",
+			Parameters: []EvaluationParameter{fooParameter},
+			Expected:   true,
+		},
+		EvaluationTest{
+			Name:       "IP not in",
+			Input:      "foo.IP1 in (foo.IP2, 127.0.0.2)",
+			Parameters: []EvaluationParameter{fooParameter},
+			Expected:   false,
+		},
+		EvaluationTest{
+			Name:       "CIDR in Network",
+			Input:      "InNetwork(foo.IP1, foo.CIDR1)",
+			Parameters: []EvaluationParameter{fooParameter},
+			Functions:  CIDRTestFunction,
+			Expected:   true,
+		},
+		EvaluationTest{
+			Name:       "CIDR not in Network",
+			Input:      "InNetwork(foo.IP1, foo.CIDR2)",
+			Parameters: []EvaluationParameter{fooParameter},
+			Functions:  CIDRTestFunction,
+			Expected:   false,
+		},
+		EvaluationTest{
+			Name:       "nil IP equal",
+			Input:      "foo.IP1 == 127.0.0.1",
+			Parameters: []EvaluationParameter{fooParameterEmptyIP},
+			Expected:   false,
+		},
+		EvaluationTest{
+			Name:       "nil IP not equal",
+			Input:      "foo.IP1 != 127.0.0.1",
+			Parameters: []EvaluationParameter{fooParameterEmptyIP},
+			Expected:   true,
+		},
+		EvaluationTest{
+			Name:       "nil IP greater than equal false result",
+			Input:      "foo.IP1 >= 127.0.0.1",
+			Parameters: []EvaluationParameter{fooParameterEmptyIP},
+			Expected:   false,
+		},
+		EvaluationTest{
+			Name:       "nil IP in",
+			Input:      "foo.IP1 in (127.0.0.1, 127.0.0.2,  127.0.0.3)",
+			Parameters: []EvaluationParameter{fooParameterEmptyIP},
+			Expected:   false,
+		},
+		EvaluationTest{
+			Name:       "nil CIDR in Network",
+			Input:      "InNetwork(foo.IP1, foo.CIDR1)",
+			Parameters: []EvaluationParameter{fooParameterEmptyIP},
+			Functions:  CIDRTestFunction,
+			Expected:   false,
+		},
 	}
 
 	runEvaluationTests(evaluationTests, test)
diff --git a/lexerState.go b/lexerState.go
index 6726e90..a773e51 100644
--- a/lexerState.go
+++ b/lexerState.go
@@ -32,6 +32,7 @@ var validLexerStates = []lexerState{
 			STRING,
 			TIME,
 			CLAUSE,
+			IP,
 		},
 	},
 
@@ -53,6 +54,7 @@ var validLexerStates = []lexerState{
 			TIME,
 			CLAUSE,
 			CLAUSE_CLOSE,
+			IP,
 		},
 	},
 
@@ -94,8 +96,29 @@ var validLexerStates = []lexerState{
 			SEPARATOR,
 		},
 	},
+
 	lexerState{
+		kind:       IP,
+		isEOF:      true,
+		isNullable: false,
+		validNextKinds: []TokenKind{
+			SEPARATOR,
+			COMPARATOR,
+			CLAUSE_CLOSE,
+			LOGICALOP,
+		},
+	},
 
+	lexerState{
+		kind:       IPNET,
+		isEOF:      true,
+		isNullable: false,
+		validNextKinds: []TokenKind{
+			CLAUSE_CLOSE,
+		},
+	},
+
+	lexerState{
 		kind:       BOOLEAN,
 		isEOF:      true,
 		isNullable: false,
@@ -203,6 +226,7 @@ var validLexerStates = []lexerState{
 			CLAUSE,
 			CLAUSE_CLOSE,
 			PATTERN,
+			IP,
 		},
 	},
 	lexerState{
@@ -300,6 +324,8 @@ var validLexerStates = []lexerState{
 			FUNCTION,
 			ACCESSOR,
 			CLAUSE,
+			IP,
+			IPNET,
 		},
 	},
 }
diff --git a/parsing.go b/parsing.go
index 40c7ed2..75e29d1 100644
--- a/parsing.go
+++ b/parsing.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"errors"
 	"fmt"
+	"net"
 	"regexp"
 	"strconv"
 	"strings"
@@ -106,13 +107,34 @@ func readToken(stream *lexerStream, state lexerState, functions map[string]Expre
 
 			tokenString = readTokenUntilFalse(stream, isNumeric)
 			tokenValue, err = strconv.ParseFloat(tokenString, 64)
+			if err == nil {
+				kind = NUMERIC
+				break
+			}
 
-			if err != nil {
-				errorMsg := fmt.Sprintf("Unable to parse numeric value '%v' to float64\n", tokenString)
-				return ExpressionToken{}, errors.New(errorMsg), false
+			tokenValue = net.ParseIP(tokenString)
+
+			if tokenValue.(net.IP) != nil {
+				tokenValue.(net.IP).String()
+				//maybe it`s CIDR?
+				if stream.canRead() {
+					stream.readCharacter()
+					tokenTail := readTokenUntilFalse(stream, isCIDRTail)
+
+					if _, ipNet, err := net.ParseCIDR(tokenString + tokenTail); err == nil {
+						kind = IPNET
+						tokenValue = *ipNet
+						break
+					}
+				}
+
+				kind = IP
+				break
 			}
-			kind = NUMERIC
-			break
+
+			errorMsg := fmt.Sprintf("Unable to parse numeric value '%v' to float64, ip cidr\n", tokenString)
+			return ExpressionToken{}, errors.New(errorMsg), false
+
 		}
 
 		// comma, separator
@@ -414,10 +436,6 @@ func checkBalance(tokens []ExpressionToken) error {
 	return nil
 }
 
-func isDigit(character rune) bool {
-	return unicode.IsDigit(character)
-}
-
 func isHexDigit(character rune) bool {
 
 	character = unicode.ToLower(character)
@@ -436,6 +454,11 @@ func isNumeric(character rune) bool {
 	return unicode.IsDigit(character) || character == '.'
 }
 
+func isCIDRTail(character rune) bool {
+
+	return unicode.IsDigit(character) || character == '/'
+}
+
 func isNotQuote(character rune) bool {
 
 	return character != '\'' && character != '"'
diff --git a/parsingFailure_test.go b/parsingFailure_test.go
index d8a3184..0c2ae74 100644
--- a/parsingFailure_test.go
+++ b/parsingFailure_test.go
@@ -161,9 +161,8 @@ func TestParsingFailure(test *testing.T) {
 			Expected: UNBALANCED_PARENTHESIS,
 		},
 		ParsingFailureTest{
-
 			Name:     "Multiple radix",
-			Input:    "127.0.0.1",
+			Input:    "127.0.0.1.1.1.1.",
 			Expected: INVALID_NUMERIC,
 		},
 		ParsingFailureTest{
diff --git a/parsing_test.go b/parsing_test.go
index d57b809..0e58393 100644
--- a/parsing_test.go
+++ b/parsing_test.go
@@ -3,6 +3,7 @@ package govaluate
 import (
 	"bytes"
 	"fmt"
+	"net"
 	"reflect"
 	"testing"
 	"time"
@@ -449,6 +450,17 @@ func TestConstantParsing(test *testing.T) {
 				},
 			},
 		},
+		TokenParsingTest{
+
+			Name:  "IP",
+			Input: "127.0.0.1",
+			Expected: []ExpressionToken{
+				ExpressionToken{
+					Kind:  IP,
+					Value: net.ParseIP("127.0.0.1"),
+				},
+			},
+		},
 	}
 
 	tokenParsingTests = combineWhitespaceExpressions(tokenParsingTests)
@@ -1633,7 +1645,7 @@ func runTokenParsingTest(tokenParsingTests []TokenParsingTest, test *testing.T)
 				continue
 			}
 
-			// gotta be an accessor
+			// gotta be an accessor or IP
 			if reflectedKind == reflect.Slice {
 
 				if actualToken.Value == nil {
@@ -1641,17 +1653,28 @@ func runTokenParsingTest(tokenParsingTests []TokenParsingTest, test *testing.T)
 					test.Logf("Expected token value '%v' does not match nil", expectedToken.Value)
 					test.Fail()
 				}
+				_, ok := actualToken.Value.([]string)
+				if ok {
+					for z, actual := range actualToken.Value.([]string) {
 
-				for z, actual := range actualToken.Value.([]string) {
-
-					if actual != expectedToken.Value.([]string)[z] {
+						if actual != expectedToken.Value.([]string)[z] {
 
+							test.Logf("Test '%s' failed:", parsingTest.Name)
+							test.Logf("Expected token value '%v' does not match '%v'", expectedToken.Value, actualToken.Value)
+							test.Fail()
+						}
+					}
+					continue
+				}
+				_, ok = actualToken.Value.(net.IP)
+				if ok {
+					if bytes.Compare(actualToken.Value.(net.IP).To4(), expectedToken.Value.(net.IP).To4()) != 0 {
 						test.Logf("Test '%s' failed:", parsingTest.Name)
 						test.Logf("Expected token value '%v' does not match '%v'", expectedToken.Value, actualToken.Value)
 						test.Fail()
 					}
+					continue
 				}
-				continue
 			}
 
 			if actualToken.Value != expectedToken.Value {
diff --git a/stagePlanner.go b/stagePlanner.go
index d71ed12..150221a 100644
--- a/stagePlanner.go
+++ b/stagePlanner.go
@@ -421,6 +421,10 @@ func planValue(stream *tokenStream) (*evaluationStage, error) {
 		fallthrough
 	case PATTERN:
 		fallthrough
+	case IP:
+		fallthrough
+	case IPNET:
+		fallthrough
 	case BOOLEAN:
 		symbol = LITERAL
 		operator = makeLiteralStage(token.Value)
diff --git a/tokenKind_test.go b/tokenKind_test.go
index 8277a95..942f207 100644
--- a/tokenKind_test.go
+++ b/tokenKind_test.go
@@ -28,6 +28,8 @@ func TestTokenKindStrings(test *testing.T) {
 		CLAUSE,
 		CLAUSE_CLOSE,
 		TERNARY,
+		IP,
+		IPNET,
 	}
 
 	for _, kind := range kinds {