Skip to content

Commit

Permalink
bugfix for abstract model enumeration
Browse files Browse the repository at this point in the history
  • Loading branch information
SHildebrandt committed Sep 23, 2024
1 parent 2a72b68 commit 51767e4
Show file tree
Hide file tree
Showing 11 changed files with 128 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import com.booleworks.logicng.formulas.FormulaFactory;
import com.booleworks.logicng.formulas.Variable;
import com.booleworks.logicng.handlers.ComputationHandler;
import com.booleworks.logicng.handlers.LngResult;
import com.booleworks.logicng.handlers.events.EnumerationFoundModelsEvent;
import com.booleworks.logicng.handlers.events.LngEvent;
import com.booleworks.logicng.solvers.SatSolver;
Expand Down Expand Up @@ -139,14 +140,14 @@ public LngEvent rollback(final ComputationHandler handler) {
}

@Override
public List<Model> rollbackAndReturnModels(final SatSolver solver, final ComputationHandler handler) {
public LngResult<List<Model>> rollbackAndReturnModels(final SatSolver solver, final ComputationHandler handler) {
final List<Model> modelsToReturn = new ArrayList<>(uncommittedModels.size());
for (int i = 0; i < uncommittedModels.size(); i++) {
modelsToReturn.add(new Model(solver.getUnderlyingSolver().convertInternalModel(uncommittedModels.get(i),
uncommittedIndices.get(i))));
}
rollback(handler);
return modelsToReturn;
final LngEvent cancelCause = rollback(handler);
return cancelCause == null ? LngResult.of(modelsToReturn) : LngResult.canceled(cancelCause);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.booleworks.logicng.formulas.Literal;
import com.booleworks.logicng.formulas.Variable;
import com.booleworks.logicng.handlers.ComputationHandler;
import com.booleworks.logicng.handlers.LngResult;
import com.booleworks.logicng.handlers.events.EnumerationFoundModelsEvent;
import com.booleworks.logicng.handlers.events.LngEvent;
import com.booleworks.logicng.solvers.SatSolver;
Expand Down Expand Up @@ -174,10 +175,10 @@ public LngEvent rollback(final ComputationHandler handler) {
}

@Override
public List<Model> rollbackAndReturnModels(final SatSolver solver, final ComputationHandler handler) {
public LngResult<List<Model>> rollbackAndReturnModels(final SatSolver solver, final ComputationHandler handler) {
final List<Model> modelsToReturn = uncommittedModels.stream().map(Model::new).collect(Collectors.toList());
rollback(handler);
return modelsToReturn;
final LngEvent cancelCause = rollback(handler);
return cancelCause == null ? LngResult.of(modelsToReturn) : LngResult.canceled(cancelCause);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import com.booleworks.logicng.formulas.Literal;
import com.booleworks.logicng.formulas.Variable;
import com.booleworks.logicng.handlers.ComputationHandler;
import com.booleworks.logicng.handlers.LngResult;
import com.booleworks.logicng.handlers.events.EnumerationFoundModelsEvent;
import com.booleworks.logicng.handlers.events.LngEvent;
import com.booleworks.logicng.knowledgecompilation.bdds.Bdd;
Expand Down Expand Up @@ -164,10 +165,10 @@ public LngEvent rollback(final ComputationHandler handler) {
}

@Override
public List<Model> rollbackAndReturnModels(final SatSolver solver, final ComputationHandler handler) {
public LngResult<List<Model>> rollbackAndReturnModels(final SatSolver solver, final ComputationHandler handler) {
final List<Model> modelsToReturn = new ArrayList<>(uncommittedModels);
rollback(handler);
return modelsToReturn;
final LngEvent cancelCause = rollback(handler);
return cancelCause == null ? LngResult.of(modelsToReturn) : LngResult.canceled(cancelCause);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,14 @@ private LngEvent enumerateRecursive(final EnumerationCollector<RESULT> collector
remainingVars.remove(literal.variable());
}

final List<Model> newSplitAssignments = collector.rollbackAndReturnModels(solver, handler);
final LngResult<List<Model>> newSplitResult = collector.rollbackAndReturnModels(solver, handler);
if (!newSplitResult.isSuccess()) {
solver.loadState(state);
return newSplitResult.getCancelCause();
}
final SortedSet<Variable> recursiveSplitVars =
strategy.splitVarsForRecursionDepth(remainingVars, solver, recursionDepth + 1);
for (final Model newSplitAssignment : newSplitAssignments) {
for (final Model newSplitAssignment : newSplitResult.getPartialResult()) {
final SortedSet<Literal> recursiveSplitModel = new TreeSet<>(newSplitAssignment.getLiterals());
recursiveSplitModel.addAll(splitModel);
enumerateRecursive(collector, solver, recursiveSplitModel, enumerationVars, recursiveSplitVars,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.booleworks.logicng.collections.LngIntVector;
import com.booleworks.logicng.datastructures.Model;
import com.booleworks.logicng.handlers.ComputationHandler;
import com.booleworks.logicng.handlers.LngResult;
import com.booleworks.logicng.handlers.events.LngEvent;
import com.booleworks.logicng.solvers.SatSolver;

Expand Down Expand Up @@ -37,7 +38,7 @@ public interface EnumerationCollector<RESULT> {
* @param relevantAllIndices the relevant indices
* @param handler the model enumeration handler
* @return an event if the handler canceled the computation,
* otherwise {@code null}
* otherwise {@code null}
*/
LngEvent addModel(LngBooleanVector modelFromSolver, SatSolver solver, LngIntVector relevantAllIndices,
ComputationHandler handler);
Expand All @@ -49,7 +50,7 @@ LngEvent addModel(LngBooleanVector modelFromSolver, SatSolver solver, LngIntVect
* Calls the {@code commit()} routine of {@code handler}.
* @param handler the computation handler
* @return an event if the handler canceled the computation,
* otherwise {@code null}
* otherwise {@code null}
*/
LngEvent commit(ComputationHandler handler);

Expand All @@ -60,7 +61,7 @@ LngEvent addModel(LngBooleanVector modelFromSolver, SatSolver solver, LngIntVect
* cancels the computation.
* @param handler the computation handler
* @return an event if the handler canceled the computation,
* otherwise {@code null}
* otherwise {@code null}
*/
LngEvent rollback(ComputationHandler handler);

Expand All @@ -70,9 +71,10 @@ LngEvent addModel(LngBooleanVector modelFromSolver, SatSolver solver, LngIntVect
* Calls the {@code rollback} routine of {@code handler}.
* @param solver solver used for the enumeration
* @param handler the computation handler
* @return list of all discarded models
* @return the LNG result with a list of all discarded models or a canceled
* result if the rollback was canceled by the handler
*/
List<Model> rollbackAndReturnModels(final SatSolver solver, ComputationHandler handler);
LngResult<List<Model>> rollbackAndReturnModels(final SatSolver solver, ComputationHandler handler);

/**
* Returns the currently committed state of the collector.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.booleworks.logicng.handlers;

import com.booleworks.logicng.handlers.events.LngEvent;

public class CallLimitComputationHandler implements ComputationHandler {
private final int n;
private int count;

public CallLimitComputationHandler(final int n) {
this.n = n;
count = 0;
}

@Override
public boolean shouldResume(final LngEvent event) {
return count++ < n;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,8 @@ public void testTimeoutHandlerFixedEnd() {

final LngResult<List<Model>> result = me.apply(solver, handler);

if (result.isSuccess()) {
System.out.println(result.getResult());
}
assertThat(result.isSuccess()).isFalse();
assertThatThrownBy(result::getResult).isInstanceOf(IllegalStateException.class);
}
}

// TODO test partial results (does not seem to work well with negated Pigeon Hole)
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,21 @@
package com.booleworks.logicng.knowledgecompilation.bdds;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import com.booleworks.logicng.LongRunningTag;
import com.booleworks.logicng.datastructures.Model;
import com.booleworks.logicng.formulas.CType;
import com.booleworks.logicng.formulas.Formula;
import com.booleworks.logicng.formulas.FormulaFactory;
import com.booleworks.logicng.formulas.Variable;
import com.booleworks.logicng.handlers.CallLimitComputationHandler;
import com.booleworks.logicng.handlers.LngResult;
import com.booleworks.logicng.knowledgecompilation.bdds.jbuddy.BddKernel;
import com.booleworks.logicng.solvers.SatSolver;
import com.booleworks.logicng.solvers.functions.ModelCountingFunction;
import com.booleworks.logicng.testutils.NQueensGenerator;
import com.booleworks.logicng.testutils.PigeonHoleGenerator;
import org.junit.jupiter.api.Test;

import java.math.BigInteger;
Expand Down Expand Up @@ -106,6 +113,20 @@ public void testAmo() {
assertThat(bdd.enumerateAllModels(generateVariables(f, 100))).hasSize(101);
}

@Test
@LongRunningTag
public void testComputationHandlerExitPoints() {
final Formula formula = new PigeonHoleGenerator(f).generate(10).negate(f);
final SatSolver solver = SatSolver.newSolver(f);
solver.add(formula);
for (int callLimit = 0; callLimit < 5000; callLimit++) {
final ModelCountingFunction me = ModelCountingFunction.builder(formula.variables(f)).build();
final LngResult<BigInteger> result = me.apply(solver, new CallLimitComputationHandler(callLimit));
assertThat(result.isSuccess()).isFalse();
assertThatThrownBy(result::getResult).isInstanceOf(IllegalStateException.class);
}
}

private List<Variable> generateVariables(final FormulaFactory f, final int n) {
final List<Variable> result = new ArrayList<>(n);
for (int i = 0; i < n; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
package com.booleworks.logicng.solvers.functions;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import com.booleworks.logicng.LongRunningTag;
import com.booleworks.logicng.RandomTag;
import com.booleworks.logicng.collections.LngBooleanVector;
import com.booleworks.logicng.collections.LngIntVector;
Expand All @@ -15,6 +17,7 @@
import com.booleworks.logicng.formulas.FormulaFactory;
import com.booleworks.logicng.formulas.TestWithFormulaContext;
import com.booleworks.logicng.formulas.Variable;
import com.booleworks.logicng.handlers.CallLimitComputationHandler;
import com.booleworks.logicng.handlers.LngResult;
import com.booleworks.logicng.handlers.NumberOfModelsHandler;
import com.booleworks.logicng.io.parsers.ParserException;
Expand All @@ -27,6 +30,7 @@
import com.booleworks.logicng.solvers.functions.modelenumeration.splitprovider.LeastCommonVariablesProvider;
import com.booleworks.logicng.solvers.functions.modelenumeration.splitprovider.MostCommonVariablesProvider;
import com.booleworks.logicng.solvers.functions.modelenumeration.splitprovider.SplitVariableProvider;
import com.booleworks.logicng.testutils.PigeonHoleGenerator;
import com.booleworks.logicng.util.FormulaRandomizer;
import com.booleworks.logicng.util.FormulaRandomizerConfig;
import org.junit.jupiter.api.BeforeEach;
Expand Down Expand Up @@ -187,12 +191,8 @@ public void testDontCareVariables3() throws ParserException {
ModelEnumerationConfig.builder().strategy(DefaultModelEnumerationStrategy.builder()
.splitVariableProvider(splitProvider).maxNumberOfModels(3).build()).build();
final SatSolver solver = SatSolver.newSolver(f);
final Formula formula = f.parse("A | B | (X & ~X)"); // X will be
// simplified out
// and become a
// don't care
// variable unknown
// by the solver
// X will be simplified out and become a don't care variable unknown by the solver
final Formula formula = f.parse("A | B | (X & ~X)");
solver.add(formula);
final SortedSet<Variable> variables = new TreeSet<>(f.variables("A", "B", "X"));
final BigInteger numberOfModels = solver.execute(ModelCountingFunction.builder(variables)
Expand Down Expand Up @@ -300,7 +300,7 @@ public void testCollector(final FormulaContext _c) {
assertThat(handler.getRollbackCalls()).isEqualTo(1);

collector.addModel(modelFromSolver2, solver, relevantIndices, handler);
final List<Model> rollbackModels = collector.rollbackAndReturnModels(solver, handler);
final List<Model> rollbackModels = collector.rollbackAndReturnModels(solver, handler).getResult();
assertThat(rollbackModels).containsExactly(expectedModel2);
assertThat(collector.getResult()).isEqualTo(result1);
assertThat(handler.getFoundModels()).isEqualTo(3);
Expand All @@ -317,9 +317,23 @@ public void testCollector(final FormulaContext _c) {

collector.rollback(handler);
assertThat(collector.getResult()).isEqualTo(result2);
assertThat(collector.rollbackAndReturnModels(solver, handler)).isEmpty();
assertThat(collector.rollbackAndReturnModels(solver, handler).getResult()).isEmpty();
assertThat(handler.getFoundModels()).isEqualTo(4);
assertThat(handler.getCommitCalls()).isEqualTo(2);
assertThat(handler.getRollbackCalls()).isEqualTo(4);
}

@Test
@LongRunningTag
public void testComputationHandlerExitPoints() {
final Formula formula = new PigeonHoleGenerator(f).generate(10).negate(f);
final SatSolver solver = SatSolver.newSolver(f);
solver.add(formula);
for (int callLimit = 0; callLimit < 5000; callLimit++) {
final ModelCountingFunction me = ModelCountingFunction.builder(formula.variables(f)).build();
final LngResult<BigInteger> result = me.apply(solver, new CallLimitComputationHandler(callLimit));
assertThat(result.isSuccess()).isFalse();
assertThatThrownBy(result::getResult).isInstanceOf(IllegalStateException.class);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import static java.util.Collections.emptySortedSet;
import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import com.booleworks.logicng.LongRunningTag;
import com.booleworks.logicng.RandomTag;
Expand All @@ -23,6 +24,7 @@
import com.booleworks.logicng.formulas.Literal;
import com.booleworks.logicng.formulas.TestWithFormulaContext;
import com.booleworks.logicng.formulas.Variable;
import com.booleworks.logicng.handlers.CallLimitComputationHandler;
import com.booleworks.logicng.handlers.LngResult;
import com.booleworks.logicng.handlers.NumberOfModelsHandler;
import com.booleworks.logicng.io.parsers.ParserException;
Expand All @@ -37,6 +39,7 @@
import com.booleworks.logicng.solvers.functions.modelenumeration.splitprovider.MostCommonVariablesProvider;
import com.booleworks.logicng.solvers.functions.modelenumeration.splitprovider.SplitVariableProvider;
import com.booleworks.logicng.solvers.sat.SatSolverConfig;
import com.booleworks.logicng.testutils.PigeonHoleGenerator;
import com.booleworks.logicng.util.FormulaHelper;
import com.booleworks.logicng.util.FormulaRandomizer;
import com.booleworks.logicng.util.FormulaRandomizerConfig;
Expand All @@ -50,6 +53,7 @@
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
Expand Down Expand Up @@ -507,8 +511,8 @@ public void testCollector(final FormulaContext _c) {
assertThat(handler.getRollbackCalls()).isEqualTo(1);

collector.addModel(modelFromSolver2, solver, relevantIndices, handler);
final List<Model> rollbackModels = collector.rollbackAndReturnModels(solver, handler);
assertThat(rollbackModels).containsExactly(expectedModel2);
final LngResult<List<Model>> rollbackModels = collector.rollbackAndReturnModels(solver, handler);
assertThat(rollbackModels.getResult()).containsExactly(expectedModel2);
assertThat(collector.getResult()).isEqualTo(result1);
assertThat(handler.getFoundModels()).isEqualTo(3);
assertThat(handler.getCommitCalls()).isEqualTo(1);
Expand All @@ -524,7 +528,7 @@ public void testCollector(final FormulaContext _c) {

collector.rollback(handler);
assertThat(collector.getResult()).isEqualTo(result2);
assertThat(collector.rollbackAndReturnModels(solver, handler)).isEmpty();
assertThat(collector.rollbackAndReturnModels(solver, handler).getResult()).isEmpty();
assertThat(handler.getFoundModels()).isEqualTo(4);
assertThat(handler.getCommitCalls()).isEqualTo(2);
assertThat(handler.getRollbackCalls()).isEqualTo(4);
Expand Down Expand Up @@ -592,6 +596,39 @@ public void testUnknownVariableNotOccurringInModel() {
assertThat(models.get(0).getLiterals()).contains(a);
}

@Test
@LongRunningTag
public void testPartialResults() {
final Map<Integer, Integer> callLimitToExpectedModels = Map.of(1, 0, 3, 1, 9, 3, 29, 9, 222, 73, 420, 139);
final Formula formula = new PigeonHoleGenerator(f).generate(10).negate(f);
for (final Map.Entry<Integer, Integer> callLimitAndExpectedModels : callLimitToExpectedModels.entrySet()) {
final SatSolver solver = SatSolver.newSolver(f);
solver.add(formula);
final Integer callLimit = callLimitAndExpectedModels.getKey();
final Integer expectedModels = callLimitAndExpectedModels.getValue();
final ModelEnumerationFunction me = ModelEnumerationFunction.builder(formula.variables(f)).build();
final LngResult<List<Model>> result = me.apply(solver, new CallLimitComputationHandler(callLimit));
assertThat(result.isSuccess()).isFalse();
assertThat(result.getCancelCause()).isNotNull();
assertThat(result.getPartialResult().size()).isEqualTo(expectedModels);
assertThatThrownBy(result::getResult).isInstanceOf(IllegalStateException.class);
}
}

@Test
@LongRunningTag
public void testComputationHandlerExitPoints() {
final Formula formula = new PigeonHoleGenerator(f).generate(10).negate(f);
final SatSolver solver = SatSolver.newSolver(f);
solver.add(formula);
for (int callLimit = 0; callLimit < 5000; callLimit++) {
final ModelEnumerationFunction me = ModelEnumerationFunction.builder(formula.variables(f)).build();
final LngResult<List<Model>> result = me.apply(solver, new CallLimitComputationHandler(callLimit));
assertThat(result.isSuccess()).isFalse();
assertThatThrownBy(result::getResult).isInstanceOf(IllegalStateException.class);
}
}

private static List<Set<Literal>> modelsToSets(final List<Model> models) {
return models.stream().map(x -> new HashSet<>(x.getLiterals())).collect(Collectors.toList());
}
Expand Down
Loading

0 comments on commit 51767e4

Please sign in to comment.