Skip to content

Commit

Permalink
Add the choices
Browse files Browse the repository at this point in the history
  • Loading branch information
DariusNafar committed Nov 30, 2024
1 parent 7a149da commit cf1b119
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions test_regr/examples/orbs/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from graph import get_graph
parser = argparse.ArgumentParser(description='Check a csp structure and atmostal/atleastal constraint in domiknows')
parser.add_argument('--colored', dest='colored', default=True,action='store_true',help="color every orb")
parser.add_argument('--constraint', dest='constraint',default="foreach_IfL_atleastL_atmostL", choices=["None","simple_constraint","foreach_bag_existsL","foreach_bag_existsL_notL","foreach_bag_atLeastAL","foreach_bag_atMostAL"], help="Choose a constraint")
parser.add_argument('--constraint', dest='constraint',default="foreach_IfL_atleastL_atmostL",
choices=["None","simple_constraint","foreach_bag_existsL","foreach_bag_existsL_notL","foreach_bag_atLeastAL","foreach_bag_atMostAL","foreach_IfL_atleastL_bag_existsL_notL","foreach_IfL_atleastL_bag_existsL","foreach_IfL_atleastL_atmostL"], help="Choose a constraint")
parser.add_argument('--atmostaL', dest='atmostaL',default=1,type=int)
parser.add_argument('--atleastaL', dest='atleastaL',default=5,type=int)

Expand Down Expand Up @@ -63,13 +64,11 @@ def connect(x,y):
for csp_range_datanode in datanode.getChildDataNodes():
print(*[int(orb_node.getAttribute('<colored_orbs>',"local/softmax")[1].item()) for orb_node in csp_range_datanode.getChildDataNodes()],end=" | ")

#assert sum([int(orb_node.getAttribute('<colored_orbs>',"local/softmax")[1].item()) for orb_node in csp_range_datanode.getChildDataNodes()])==0
datanode.inferILPResults()
datanode.inferILPResults()

print("\n\nafter inference")
print("orb color:",end="")
for csp_range_datanode in datanode.getChildDataNodes():
print(*[int(orb_node.getAttribute('<colored_orbs>',"ILP").item()) for orb_node in csp_range_datanode.getChildDataNodes()],end=" | ")

#assert sum([int(orb_node.getAttribute('<colored_orbs>',"ILP").item()) for orb_node in csp_range_datanode.getChildDataNodes()])==2
print()

0 comments on commit cf1b119

Please sign in to comment.