diff --git a/lib/src/main/java/com/ledmington/gal/Utils.java b/lib/src/main/java/com/ledmington/gal/Utils.java index ed42745..8945d5a 100644 --- a/lib/src/main/java/com/ledmington/gal/Utils.java +++ b/lib/src/main/java/com/ledmington/gal/Utils.java @@ -36,25 +36,32 @@ public static Supplier weightedChoose( throw new IllegalArgumentException("The list of values cannot be empty"); } - final Function safeWeight = x -> { - final double result = weight.apply(x); - if (result < 0.0) { - throw new IllegalArgumentException(String.format( - "Negative weights are not allowed: the object '%s' produced the weight %f", - x.toString(), result)); - } - return result; - }; - final double totalWeight = - values.stream().mapToDouble(safeWeight::apply).sum(); + double minWeight = Double.POSITIVE_INFINITY; + double maxWeight = Double.NEGATIVE_INFINITY; + double totalWeight = 0.0; + for (final X x : values) { + final double w = weight.apply(x); + minWeight = Math.min(minWeight, w); + maxWeight = Math.max(maxWeight, w); + totalWeight += w; + } + + if (minWeight == maxWeight) { + // if they all have the same weight, return a special function which treats all + // values equally + return () -> values.get(rng.nextInt(0, values.size())); + } + + final double finalMinWeight = minWeight; + final double finalTotalWeight = totalWeight - finalMinWeight * values.size(); return () -> { - final double chosenWeight = rng.nextDouble(0.0, totalWeight); + final double chosenWeight = rng.nextDouble(0.0, finalTotalWeight); double sum = 0.0; for (int i = 0; i < values.size() - 1; i++) { final X ith_element = values.get(i); - sum += safeWeight.apply(ith_element); + sum += (weight.apply(ith_element) - finalMinWeight); if (sum >= chosenWeight) { return ith_element; } diff --git a/lib/src/test/java/com/ledmington/gal/TestUtils.java b/lib/src/test/java/com/ledmington/gal/TestUtils.java index 27880c7..38f2157 100644 --- a/lib/src/test/java/com/ledmington/gal/TestUtils.java +++ b/lib/src/test/java/com/ledmington/gal/TestUtils.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.Map; import java.util.function.Function; +import java.util.function.Supplier; import java.util.random.RandomGenerator; import java.util.random.RandomGeneratorFactory; import java.util.stream.IntStream; @@ -86,8 +87,9 @@ public void weightsWork() { count.put(x, 0); } + final Supplier weightedChoose = Utils.weightedChoose(values, w, rng); for (int i = 0; i < 10_000; i++) { - final Integer chosen = Utils.weightedChoose(values, w, rng).get(); + final Integer chosen = weightedChoose.get(); count.put(chosen, count.get(chosen) + 1); } @@ -106,11 +108,33 @@ public void weightsWork() { } @Test - public void negativeWeightsDoNotWork() { + public void negativeWeightsWorkAsWell() { final List values = List.of(1, 2, 3, 4, 5, 6, 7, 8, 9); final Function w = x -> -(double) x; - assertThrows(IllegalArgumentException.class, () -> Utils.weightedChoose(values, w, rng)); + final Map count = new HashMap<>(); + for (final Integer x : values) { + count.put(x, 0); + } + + final Supplier weightedChoose = Utils.weightedChoose(values, w, rng); + for (int i = 0; i < 10_000; i++) { + final Integer chosen = weightedChoose.get(); + count.put(chosen, count.get(chosen) + 1); + } + + for (int i = 0; i < values.size() - 1; i++) { + final Integer first = values.get(i); + final Integer second = values.get(i + 1); + assertTrue( + count.get(first) > 0, + String.format("Value %d (with weight %f) did not appear once", first, w.apply(first))); + assertTrue( + count.get(first) < count.get(second), + String.format( + "Value %d (with weight %f) appeared more often than value %d (with weight %f): %,d > %,d", + first, w.apply(first), second, w.apply(second), count.get(first), count.get(second))); + } } @Test