From c4711a9bda44e926104c3801ad87492753ed9ecd Mon Sep 17 00:00:00 2001 From: Toshiya Kobayashi Date: Fri, 13 Dec 2024 18:05:42 +0900 Subject: [PATCH] =?UTF-8?q?[incubator-kie-drools-6180]=20accumulate=20min?= =?UTF-8?q?=20doesn't=20evaluate=20correctly=E2=80=A6=20(#6186)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [incubator-kie-drools-6180] accumulate min doesn't evaluate correctly with more than 18 digits BigDecimal * fixed assertion --- .../rule/builder/util/AccumulateUtil.java | 8 + .../BigDecimalMaxAccumulateFunction.java | 100 +++++++++++ .../BigDecimalMinAccumulateFunction.java | 98 +++++++++++ .../BigIntegerMaxAccumulateFunction.java | 100 +++++++++++ .../BigIntegerMinAccumulateFunction.java | 98 +++++++++++ .../META-INF/kie.default.properties.conf | 4 + .../integrationtests/AccumulateTest.java | 156 ++++++++++++++++++ 7 files changed, 564 insertions(+) create mode 100644 drools-core/src/main/java/org/drools/core/base/accumulators/BigDecimalMaxAccumulateFunction.java create mode 100644 drools-core/src/main/java/org/drools/core/base/accumulators/BigDecimalMinAccumulateFunction.java create mode 100644 drools-core/src/main/java/org/drools/core/base/accumulators/BigIntegerMaxAccumulateFunction.java create mode 100644 drools-core/src/main/java/org/drools/core/base/accumulators/BigIntegerMinAccumulateFunction.java diff --git a/drools-compiler/src/main/java/org/drools/compiler/rule/builder/util/AccumulateUtil.java b/drools-compiler/src/main/java/org/drools/compiler/rule/builder/util/AccumulateUtil.java index 192873909f1..9fc2564123a 100644 --- a/drools-compiler/src/main/java/org/drools/compiler/rule/builder/util/AccumulateUtil.java +++ b/drools-compiler/src/main/java/org/drools/compiler/rule/builder/util/AccumulateUtil.java @@ -55,6 +55,10 @@ public static String getFunctionName(Supplier> exprClassSupplier, Strin functionName = "maxI"; } else if (exprClass == Long.class) { functionName = "maxL"; + } else if (exprClass == BigInteger.class) { + functionName = "maxBI"; + } else if (exprClass == BigDecimal.class) { + functionName = "maxBD"; } else if (Number.class.isAssignableFrom( exprClass )) { functionName = "maxN"; } @@ -64,6 +68,10 @@ public static String getFunctionName(Supplier> exprClassSupplier, Strin functionName = "minI"; } else if (exprClass == Long.class) { functionName = "minL"; + } else if (exprClass == BigInteger.class) { + functionName = "minBI"; + } else if (exprClass == BigDecimal.class) { + functionName = "minBD"; } else if (Number.class.isAssignableFrom( exprClass )) { functionName = "minN"; } diff --git a/drools-core/src/main/java/org/drools/core/base/accumulators/BigDecimalMaxAccumulateFunction.java b/drools-core/src/main/java/org/drools/core/base/accumulators/BigDecimalMaxAccumulateFunction.java new file mode 100644 index 00000000000..884fa35a822 --- /dev/null +++ b/drools-core/src/main/java/org/drools/core/base/accumulators/BigDecimalMaxAccumulateFunction.java @@ -0,0 +1,100 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.drools.core.base.accumulators; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.math.BigDecimal; + +/** + * An implementation of an accumulator capable of calculating maximum values + */ +public class BigDecimalMaxAccumulateFunction extends AbstractAccumulateFunction { + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + + } + + public void writeExternal(ObjectOutput out) throws IOException { + + } + + protected static class MaxData implements Externalizable { + public BigDecimal max = null; + + public MaxData() {} + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + max = (BigDecimal) in.readObject(); + } + + public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(max); + } + + @Override + public String toString() { + return "max"; + } + } + + public MaxData createContext() { + return new MaxData(); + } + + public void init(MaxData data) { + data.max = null; + } + + public void accumulate(MaxData data, + Object value) { + if (value != null) { + BigDecimal bdValue = (BigDecimal) value; + data.max = data.max == null || data.max.compareTo(bdValue) < 0 ? + bdValue : + data.max; + } + } + + public void reverse(MaxData data, + Object value) { + } + + @Override + public boolean tryReverse( MaxData data, Object value ) { + if (value != null) { + return data.max.compareTo((BigDecimal) value) > 0; + } + return true; + } + + public Object getResult(MaxData data) { + return data.max; + } + + public boolean supportsReverse() { + return false; + } + + public Class getResultType() { + return BigDecimal.class; + } +} diff --git a/drools-core/src/main/java/org/drools/core/base/accumulators/BigDecimalMinAccumulateFunction.java b/drools-core/src/main/java/org/drools/core/base/accumulators/BigDecimalMinAccumulateFunction.java new file mode 100644 index 00000000000..c3f0a628315 --- /dev/null +++ b/drools-core/src/main/java/org/drools/core/base/accumulators/BigDecimalMinAccumulateFunction.java @@ -0,0 +1,98 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.drools.core.base.accumulators; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.math.BigDecimal; + +/** + * An implementation of an accumulator capable of calculating minimum values + */ +public class BigDecimalMinAccumulateFunction extends AbstractAccumulateFunction { + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + } + + public void writeExternal(ObjectOutput out) throws IOException { + } + + protected static class MinData implements Externalizable { + public BigDecimal min = null; + + public MinData() {} + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + min = (BigDecimal) in.readObject(); + } + + public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(min); + } + + @Override + public String toString() { + return "min"; + } + } + + public MinData createContext() { + return new MinData(); + } + + public void init(MinData data) { + data.min = null; + } + + public void accumulate(MinData data, + Object value) { + if (value != null) { + BigDecimal bdValue = (BigDecimal) value; + data.min = data.min == null || data.min.compareTo(bdValue) > 0 ? + bdValue : + data.min; + } + } + + @Override + public boolean tryReverse( MinData data, Object value ) { + if (value != null) { + return data.min.compareTo((BigDecimal) value) < 0; + } + return true; + } + + public void reverse(MinData data, + Object value) { + } + + public Object getResult(MinData data) { + return data.min; + } + + public boolean supportsReverse() { + return false; + } + + public Class getResultType() { + return BigDecimal.class; + } +} diff --git a/drools-core/src/main/java/org/drools/core/base/accumulators/BigIntegerMaxAccumulateFunction.java b/drools-core/src/main/java/org/drools/core/base/accumulators/BigIntegerMaxAccumulateFunction.java new file mode 100644 index 00000000000..2fee047f05d --- /dev/null +++ b/drools-core/src/main/java/org/drools/core/base/accumulators/BigIntegerMaxAccumulateFunction.java @@ -0,0 +1,100 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.drools.core.base.accumulators; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.math.BigInteger; + +/** + * An implementation of an accumulator capable of calculating maximum values + */ +public class BigIntegerMaxAccumulateFunction extends AbstractAccumulateFunction { + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + + } + + public void writeExternal(ObjectOutput out) throws IOException { + + } + + protected static class MaxData implements Externalizable { + public BigInteger max = null; + + public MaxData() {} + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + max = (BigInteger) in.readObject(); + } + + public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(max); + } + + @Override + public String toString() { + return "max"; + } + } + + public MaxData createContext() { + return new MaxData(); + } + + public void init(MaxData data) { + data.max = null; + } + + public void accumulate(MaxData data, + Object value) { + if (value != null) { + BigInteger biValue = (BigInteger) value; + data.max = data.max == null || data.max.compareTo(biValue) < 0 ? + biValue : + data.max; + } + } + + public void reverse(MaxData data, + Object value) { + } + + @Override + public boolean tryReverse( MaxData data, Object value ) { + if (value != null) { + return data.max.compareTo((BigInteger) value) > 0; + } + return true; + } + + public Object getResult(MaxData data) { + return data.max; + } + + public boolean supportsReverse() { + return false; + } + + public Class getResultType() { + return BigInteger.class; + } +} diff --git a/drools-core/src/main/java/org/drools/core/base/accumulators/BigIntegerMinAccumulateFunction.java b/drools-core/src/main/java/org/drools/core/base/accumulators/BigIntegerMinAccumulateFunction.java new file mode 100644 index 00000000000..f292e73afff --- /dev/null +++ b/drools-core/src/main/java/org/drools/core/base/accumulators/BigIntegerMinAccumulateFunction.java @@ -0,0 +1,98 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.drools.core.base.accumulators; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.math.BigInteger; + +/** + * An implementation of an accumulator capable of calculating minimum values + */ +public class BigIntegerMinAccumulateFunction extends AbstractAccumulateFunction { + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + } + + public void writeExternal(ObjectOutput out) throws IOException { + } + + protected static class MinData implements Externalizable { + public BigInteger min = null; + + public MinData() {} + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + min = (BigInteger) in.readObject(); + } + + public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(min); + } + + @Override + public String toString() { + return "min"; + } + } + + public MinData createContext() { + return new MinData(); + } + + public void init(MinData data) { + data.min = null; + } + + public void accumulate(MinData data, + Object value) { + if (value != null) { + BigInteger biValue = (BigInteger) value; + data.min = data.min == null || data.min.compareTo(biValue) > 0 ? + biValue : + data.min; + } + } + + @Override + public boolean tryReverse( MinData data, Object value ) { + if (value != null) { + return data.min.compareTo((BigInteger) value) < 0; + } + return true; + } + + public void reverse(MinData data, + Object value) { + } + + public Object getResult(MinData data) { + return data.min; + } + + public boolean supportsReverse() { + return false; + } + + public Class getResultType() { + return BigInteger.class; + } +} diff --git a/drools-core/src/main/resources/META-INF/kie.default.properties.conf b/drools-core/src/main/resources/META-INF/kie.default.properties.conf index f3fcd0b2e69..823723c7e17 100644 --- a/drools-core/src/main/resources/META-INF/kie.default.properties.conf +++ b/drools-core/src/main/resources/META-INF/kie.default.properties.conf @@ -44,10 +44,14 @@ drools.accumulate.function.max = org.drools.core.base.accumulators.MaxAccumulate drools.accumulate.function.maxN = org.drools.core.base.accumulators.NumericMaxAccumulateFunction drools.accumulate.function.maxI = org.drools.core.base.accumulators.IntegerMaxAccumulateFunction drools.accumulate.function.maxL = org.drools.core.base.accumulators.LongMaxAccumulateFunction +drools.accumulate.function.maxBI = org.drools.core.base.accumulators.BigIntegerMaxAccumulateFunction +drools.accumulate.function.maxBD = org.drools.core.base.accumulators.BigDecimalMaxAccumulateFunction drools.accumulate.function.min = org.drools.core.base.accumulators.MinAccumulateFunction drools.accumulate.function.minN = org.drools.core.base.accumulators.NumericMinAccumulateFunction drools.accumulate.function.minI = org.drools.core.base.accumulators.IntegerMinAccumulateFunction drools.accumulate.function.minL = org.drools.core.base.accumulators.LongMinAccumulateFunction +drools.accumulate.function.minBI = org.drools.core.base.accumulators.BigIntegerMinAccumulateFunction +drools.accumulate.function.minBD = org.drools.core.base.accumulators.BigDecimalMinAccumulateFunction drools.accumulate.function.count = org.drools.core.base.accumulators.CountAccumulateFunction drools.accumulate.function.collectList = org.drools.core.base.accumulators.CollectListAccumulateFunction drools.accumulate.function.collectSet = org.drools.core.base.accumulators.CollectSetAccumulateFunction diff --git a/drools-test-coverage/test-compiler-integration/src/test/java/org/drools/compiler/integrationtests/AccumulateTest.java b/drools-test-coverage/test-compiler-integration/src/test/java/org/drools/compiler/integrationtests/AccumulateTest.java index 53b1c746f18..4213446b7b6 100644 --- a/drools-test-coverage/test-compiler-integration/src/test/java/org/drools/compiler/integrationtests/AccumulateTest.java +++ b/drools-test-coverage/test-compiler-integration/src/test/java/org/drools/compiler/integrationtests/AccumulateTest.java @@ -22,6 +22,7 @@ import java.io.ObjectOutput; import java.io.Serializable; import java.math.BigDecimal; +import java.math.BigInteger; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -37,6 +38,7 @@ import org.drools.core.RuleSessionConfiguration; import org.drools.commands.runtime.rule.InsertElementsCommand; import org.drools.kiesession.rulebase.InternalKnowledgeBase; +import org.drools.mvel.compiler.Primitives; import org.drools.testcoverage.common.model.Cheese; import org.drools.testcoverage.common.model.Cheesery; import org.drools.testcoverage.common.model.Order; @@ -3941,4 +3943,158 @@ public void testPeerCollectWithEager(KieBaseTestConfiguration kieBaseTestConfigu kieSession.dispose(); } } + + @ParameterizedTest(name = "KieBase type={0}") + @MethodSource("parameters") + void minWithBigDecimalHighAccuracy(KieBaseTestConfiguration kieBaseTestConfiguration) { + final String drl = + "import " + Primitives.class.getCanonicalName() + ";\n" + + "global java.util.List results;\n" + + "rule R1 when\n" + + " accumulate(Primitives($bd : bigDecimal), $min : min($bd))\n" + + "then\n" + + " results.add($min);\n" + + " results.add($min.scale());\n" + // BigDecimal method + "end"; + + final KieBase kieBase = KieBaseUtil.getKieBaseFromKieModuleFromDrl("accumulate-test", kieBaseTestConfiguration, drl); + final KieSession kieSession = kieBase.newKieSession(); + try { + List results = new ArrayList<>(); + kieSession.setGlobal("results", results); + Primitives p1 = new Primitives(); + p1.setBigDecimal(new BigDecimal("2024043020240501130000")); + Primitives p2_smallest = new Primitives(); + p2_smallest.setBigDecimal(new BigDecimal("2024043020240501120000")); + Primitives p3 = new Primitives(); + p3.setBigDecimal(new BigDecimal("2024043020240501150000")); + + kieSession.insert(p1); + kieSession.insert(p2_smallest); + kieSession.insert(p3); + kieSession.fireAllRules(); + assertThat(results).hasSize(2); + assertThat(results.get(0)).isEqualTo(p2_smallest.getBigDecimal()); + assertThat(results.get(1)).isEqualTo(0); + } finally { + kieSession.dispose(); + } + } + + @ParameterizedTest(name = "KieBase type={0}") + @MethodSource("parameters") + void minWithBigIntegerHighAccuracy(KieBaseTestConfiguration kieBaseTestConfiguration) { + final String drl = + "import " + Primitives.class.getCanonicalName() + ";\n" + + "global java.util.List results;\n" + + "rule R1 when\n" + + " accumulate(Primitives($bi : bigInteger), $min : min($bi))\n" + + "then\n" + + " results.add($min);\n" + + " results.add($min.nextProbablePrime());\n" + // BigInteger method + "end"; + + final KieBase kieBase = KieBaseUtil.getKieBaseFromKieModuleFromDrl("accumulate-test", kieBaseTestConfiguration, drl); + final KieSession kieSession = kieBase.newKieSession(); + try { + List results = new ArrayList<>(); + kieSession.setGlobal("results", results); + Primitives p1 = new Primitives(); + p1.setBigInteger(new BigInteger("2024043020240501130000")); + Primitives p2_smallest = new Primitives(); + p2_smallest.setBigInteger(new BigInteger("2024043020240501120000")); + Primitives p3 = new Primitives(); + p3.setBigInteger(new BigInteger("2024043020240501150000")); + + kieSession.insert(p1); + kieSession.insert(p2_smallest); + kieSession.insert(p3); + kieSession.fireAllRules(); + assertThat(results).hasSize(2); + assertThat(results.get(0)).isEqualTo(p2_smallest.getBigInteger()); + + // nextProbablePrime value is not important. + // Just to make sure it doesn't raise a compilation error. + assertThat(results.get(1)).isNotNull(); + } finally { + kieSession.dispose(); + } + } + + @ParameterizedTest(name = "KieBase type={0}") + @MethodSource("parameters") + void maxWithBigDecimalHighAccuracy(KieBaseTestConfiguration kieBaseTestConfiguration) { + final String drl = + "import " + Primitives.class.getCanonicalName() + ";\n" + + "global java.util.List results;\n" + + "rule R1 when\n" + + " accumulate(Primitives($bd : bigDecimal), $max : max($bd))\n" + + "then\n" + + " results.add($max);\n" + + " results.add($max.scale());\n" + // BigDecimal method + "end"; + + final KieBase kieBase = KieBaseUtil.getKieBaseFromKieModuleFromDrl("accumulate-test", kieBaseTestConfiguration, drl); + final KieSession kieSession = kieBase.newKieSession(); + try { + List results = new ArrayList<>(); + kieSession.setGlobal("results", results); + Primitives p1 = new Primitives(); + p1.setBigDecimal(new BigDecimal("2024043020240501130000")); + Primitives p2_largest = new Primitives(); + p2_largest.setBigDecimal(new BigDecimal("2024043020240501150000")); + Primitives p3 = new Primitives(); + p3.setBigDecimal(new BigDecimal("2024043020240501120000")); + + kieSession.insert(p1); + kieSession.insert(p2_largest); + kieSession.insert(p3); + kieSession.fireAllRules(); + assertThat(results).hasSize(2); + assertThat(results.get(0)).isEqualTo(p2_largest.getBigDecimal()); + assertThat(results.get(1)).isEqualTo(0); + } finally { + kieSession.dispose(); + } + } + + @ParameterizedTest(name = "KieBase type={0}") + @MethodSource("parameters") + void maxWithBigIntegerHighAccuracy(KieBaseTestConfiguration kieBaseTestConfiguration) { + final String drl = + "import " + Primitives.class.getCanonicalName() + ";\n" + + "global java.util.List results;\n" + + "rule R1 when\n" + + " accumulate(Primitives($bi : bigInteger), $max : max($bi))\n" + + "then\n" + + " results.add($max);\n" + + " results.add($max.nextProbablePrime());\n" + // BigInteger method + "end"; + + final KieBase kieBase = KieBaseUtil.getKieBaseFromKieModuleFromDrl("accumulate-test", kieBaseTestConfiguration, drl); + final KieSession kieSession = kieBase.newKieSession(); + try { + List results = new ArrayList<>(); + kieSession.setGlobal("results", results); + Primitives p1 = new Primitives(); + p1.setBigInteger(new BigInteger("2024043020240501130000")); + Primitives p2_largest = new Primitives(); + p2_largest.setBigInteger(new BigInteger("2024043020240501150000")); + Primitives p3 = new Primitives(); + p3.setBigInteger(new BigInteger("2024043020240501120000")); + + kieSession.insert(p1); + kieSession.insert(p2_largest); + kieSession.insert(p3); + kieSession.fireAllRules(); + assertThat(results).hasSize(2); + assertThat(results.get(0)).isEqualTo(p2_largest.getBigInteger()); + + // nextProbablePrime value is not important. + // Just to make sure it doesn't raise a compilation error. + assertThat(results.get(1)).isNotNull(); + } finally { + kieSession.dispose(); + } + } }