summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--convert_to_sat.py266
1 files changed, 266 insertions, 0 deletions
diff --git a/convert_to_sat.py b/convert_to_sat.py
new file mode 100644
index 0000000..d538b08
--- /dev/null
+++ b/convert_to_sat.py
@@ -0,0 +1,266 @@
+from z3 import *
+import sys
+
+
+variables = {}
+def reachable(graph, node_reachable):
+ # check reachability to node_reachable
+ clauses = []
+ # at least one pebble can be gotted to node_reachble (from config or otherwise)
+ clauses.append(1 <= Sum(list(variables[node_reachable]["in"].values())))
+ return clauses
+
+def build_model(graph, add_property, node_reachable, config=None, config_count=0):
+ clauses = []
+ for node in graph["v"]:
+ variables[node] = {"out":{}, "in":{}}
+ # creat symbolic variables for each edge in/out
+ for node in graph["v"]:
+ for edge in graph["e"][node]:
+ variables[node]["out"][edge] = Int(node+"_out_"+edge)
+ variables[edge]["in"][node] = Int(edge+"_in_"+node)
+ # add initial pebbles
+ variables[node]["in"]["config"] = Int(node+"_in_config")
+
+
+ if config is not None:
+ for node in graph["v"]:
+ clauses.append(variables[node]["in"]["config"] == config[node])
+ else:
+ sum_terms = []
+ for node in graph["v"]:
+ clauses.append(variables[node]["in"]["config"] >= 0)
+ sum_terms.append(variables[node]["in"]["config"])
+ clauses.append(Sum(sum_terms) == config_count)
+
+ # make edges linked
+ for node in graph["v"]:
+ for edge in graph["e"][node]:
+ clauses.append(variables[node]["out"][edge] == variables[edge]["in"][node])
+
+ # encode possible moves
+ # make sure we don't move too many total from pebble
+ for node in graph["v"]:
+ sum_out = Sum(list(variables[node]["out"].values()))
+ sum_in = Sum(list(variables[node]["in"].values()))
+ # The number of pebbles out from a node is no more than half the sum
+ # of all pebbles that are moved into a node
+ clauses.append( Implies(sum_in % 2 == 0, sum_out <= (sum_in/2 )) )
+ clauses.append( Implies(sum_in % 2 == 1, sum_out <= ((sum_in-1)/2 )) )
+ # make sure each move is a valid amount
+ for node in graph["v"]:
+ for in_var in variables[node]["in"].values():
+ clauses.append(in_var >= 0)
+ for out_var in variables[node]["out"].values():
+ sum_in = Sum(list(variables[node]["in"].values()))
+ clauses.append(
+ Implies(And(sum_in % 2 == 0, sum_in > 2),
+ And(0 <= out_var, out_var <= (sum_in/2 ))) )
+ clauses.append(
+ Implies(And(sum_in % 2 == 1, sum_in > 2),
+ And(0 <= out_var, out_var <= ((sum_in-1)/2 ))) )
+
+ # encode cycle lemma
+ # TODO get all cycles, and remove them?
+ for node in graph["v"]:
+ # if we along one edge, do not move back along it
+ for edge in graph["e"][node]:
+ in1 = variables[node]["in"][edge]
+ in2 = variables[edge]["in"][node]
+ clauses.append(Implies(in1 > 0, in2 == 0))
+ clauses += add_property(graph, node_reachable)
+
+ return And(clauses)
+
+def is_sat(f, output=True):
+ s = Solver()
+ s.add(f)
+ if output:
+ print(f)
+ if s.check() == sat:
+ if output:
+ print("SAT")
+ model = s.model()
+ # print(model)
+ moves = []
+ for key in model:
+ if "_in_" in key.name() and 0<model[key].as_long():
+ parts = key.name().split("_in_")
+ if parts[1] != "config":
+ moves.append(parts[1] + " sends " + str(model[key])+" pebbles to " +parts[0])
+ moves.sort()
+ if output:
+ print("\n".join(moves))
+ return moves
+ if output:
+ print("UNSAT")
+ return None
+
+def test_simple_path():
+ config = {
+ "a": 4,
+ "b": 0,
+ "c": 0,
+ }
+ graph = {
+ "v": ["a", "b", "c"],
+ "e": {
+ "a": ["b"],
+ "b": ["a", "c"],
+ "c": ["b"],
+ },
+ }
+ node_reachable = "c"
+ run_reachable_test_case(graph, config, node_reachable, True)
+
+ config = {
+ "a": 3,
+ "b": 0,
+ "c": 0,
+ }
+ run_reachable_test_case(graph, config, node_reachable, False)
+
+ config = {
+ "a": 5,
+ "b": 0,
+ "c": 0,
+ }
+ run_reachable_test_case(graph, config, node_reachable, True)
+
+ config = {
+ "a": 2,
+ "b": 1,
+ "c": 0,
+ }
+ run_reachable_test_case(graph, config, node_reachable, True)
+
+def test_cycle():
+ config = {
+ "a": 4,
+ "b": 0,
+ "c": 0,
+ "d": 0,
+ }
+ graph = {
+ "v": ["a", "b", "c", "d"],
+ "e": {
+ "a": ["b", "d"],
+ "b": ["a", "c"],
+ "c": ["b", "d"],
+ "d": ["a", "c"],
+ },
+ }
+ node_reachable = "c"
+ run_reachable_test_case(graph, config, node_reachable, True)
+
+def test_merge():
+ # a
+ # / \
+ # b c
+ # \ /
+ # d
+ # |
+ # e
+ graph = {
+ "v": ["a", "b", "c", "d", "e"],
+ "e": {
+ "a": ["b", "c"],
+ "b": ["a", "d"],
+ "c": ["a", "d"],
+ "d": ["b", "c", "e"],
+ "e": ["d"],
+ },
+ }
+ node_reachable = "e"
+ config = {
+ "a": 4,
+ "b": 1,
+ "c": 1,
+ "d": 0,
+ "e": 0,
+ }
+ run_reachable_test_case(graph, config, node_reachable, True)
+
+ config = {
+ "a": 4,
+ "b": 1,
+ "c": 0,
+ "d": 0,
+ "e": 0,
+ }
+ run_reachable_test_case(graph, config, node_reachable, False)
+
+def test_lemke():
+ # /-x-\
+ # / | | \
+ # a-b c d
+ # \ \ | /
+ # e--f
+ # \/
+ # v
+ graph = {
+ "v": ["a", "b", "c", "d", "e", "f", "v", "x"],
+ "e": {
+ "a": ["b", "e", "x"],
+ "b": ["a", "f", "x"],
+ "c": ["x", "f"],
+ "d": ["x", "f"],
+ "e": ["a", "f", "v"],
+ "f": ["b", "c", "d", "e", "v"],
+ "v": ["e", "f"],
+ "x": ["a", "b", "c", "d"],
+ },
+ }
+ config = {
+ "a": 0,
+ "b": 0,
+ "c": 0,
+ "d": 0,
+ "e": 0,
+ "f": 0,
+ "v": 0,
+ "x": 8,
+ }
+ run_reachable_test_case(graph, config, "v", True)
+
+ config = {
+ "a": 1,
+ "b": 1,
+ "c": 1,
+ "d": 1,
+ "e": 0,
+ "f": 0,
+ "v": 0,
+ "x": 4,
+ }
+ run_reachable_test_case(graph, config, "v", True)
+ run_pebbling_number_test_case(graph, 7)
+
+def run_reachable_test_case(graph, config, node_reachable, expected_is_sat):
+ variables = {}
+ z3_model = build_model(graph, reachable, node_reachable, config=config)
+ out = is_sat(z3_model, output=False)
+ if (expected_is_sat and not out) or (not expected_is_sat and out):
+ print(z3_model)
+ print(out)
+ print("failed!")
+
+def run_pebbling_number_test_case(graph, count):
+ variables = {}
+ for node_reachable in graph["v"]:
+ # TODO model is wrong, we need to say there is not a configure in
+ # which this isn't possible. This is solvability?
+ z3_model = build_model(graph, reachable, node_reachable, config_count=count)
+ out = is_sat(z3_model, output=False)
+ if not out:
+ print("failed peb. number: ", count, node_reachable)
+ print(out)
+
+def main():
+ test_simple_path()
+ test_cycle()
+ test_merge()
+ test_lemke()
+
+if __name__ == '__main__':
+ main()