From cb0248b594f1aa85b4fb378787dbaac00457cd0b Mon Sep 17 00:00:00 2001 From: jiening Date: Thu, 7 May 2026 15:42:45 +0800 Subject: [PATCH] feat: update monabs --- aria/monabs/cores/__init__.py | 13 +- aria/monabs/cores/con_check.py | 299 ------------ aria/monabs/cores/con_check_pysmt.py | 302 ------------- aria/monabs/cores/dis_check.py | 116 ----- aria/monabs/cores/dis_check_pysmt.py | 102 +++-- aria/monabs/cores/new_check_pysmt.py | 134 ++++++ aria/monabs/cores/unary_check.py | 119 ----- aria/monabs/cores/unary_check_pysmt.py | 207 +++++---- aria/monabs/tests/test_pysmt_monabs.py | 231 +++++++--- aria/monabs/utils/__init__.py | 5 +- aria/monabs/utils/config.py | 1 + aria/monabs/utils/formular_generator.py | 325 ------------- aria/monabs/utils/logger.py | 42 ++ aria/monabs/utils/parse_monabs.py | 578 ------------------------ aria/monabs/utils/parse_monabs_pysmt.py | 6 + aria/monabs/utils/utils.py | 10 + 16 files changed, 563 insertions(+), 1927 deletions(-) delete mode 100644 aria/monabs/cores/con_check.py delete mode 100644 aria/monabs/cores/con_check_pysmt.py delete mode 100644 aria/monabs/cores/dis_check.py create mode 100644 aria/monabs/cores/new_check_pysmt.py delete mode 100644 aria/monabs/cores/unary_check.py create mode 100644 aria/monabs/utils/config.py delete mode 100644 aria/monabs/utils/formular_generator.py create mode 100644 aria/monabs/utils/logger.py delete mode 100644 aria/monabs/utils/parse_monabs.py create mode 100644 aria/monabs/utils/utils.py diff --git a/aria/monabs/cores/__init__.py b/aria/monabs/cores/__init__.py index f10322af..a277d71a 100644 --- a/aria/monabs/cores/__init__.py +++ b/aria/monabs/cores/__init__.py @@ -1,15 +1,20 @@ """Core checking functions for monadic predicate abstraction.""" -from .unary_check import ( +# LS +from .unary_check_pysmt import ( unary_check, - unary_check_incremental, unary_check_cached, + unary_check_incremental, unary_check_incremental_cached, ) -from .dis_check import ( +# OA +from .dis_check_pysmt import ( disjunctive_check_cached, disjunctive_check_incremental_cached, ) -from .con_check import conjunctive_check, conjunctive_check_incremental +# New Algorithms +from .new_check_pysmt import ( + core_lit_filter, +) \ No newline at end of file diff --git a/aria/monabs/cores/con_check.py b/aria/monabs/cores/con_check.py deleted file mode 100644 index cedf39e6..00000000 --- a/aria/monabs/cores/con_check.py +++ /dev/null @@ -1,299 +0,0 @@ -""" -Conjunctive/disjunctive satisfiability helpers with optional caching. - -Result encoding (shared across helpers): -- 1: constraint satisfiable under the given precondition -- 0: unsatisfiable -- 2: unknown (e.g., solver returned unknown) - -""" - -import z3 -from typing import List - - -def unary_check_cached( - precond: z3.ExprRef, - cnt_list: List[z3.ExprRef], - results: List[int], - check_list: List[int], -): - """Non-incremental check with model-based caching over a subset of constraints.""" - for i in check_list: - if results[i] is not None: - continue - solver = z3.Solver() - solver.add(precond) - solver.add(cnt_list[i]) # Add the current constraint - res = solver.check() - if res == z3.sat: - model = solver.model() - results[i] = 1 - for j in check_list: - if results[j] is None and z3.is_true( - model.eval(cnt_list[j], model_completion=True) - ): - results[j] = 1 - elif res == z3.unsat: - results[i] = 0 - else: - results[i] = 2 - - -def unary_check_incremental_cached( - solver: z3.Solver, - cnt_list: List[z3.ExprRef], - results: List[int], - check_list: List[int], -): - """Incremental version of unary_check_cached using a shared solver with push/pop.""" - for i in check_list: - if results[i] is not None: - continue - solver.push() # Save the current state - solver.add(cnt_list[i]) # Add the current constraint - res = solver.check() - if res == z3.sat: - model = solver.model() - results[i] = 1 - for j in check_list: - if results[j] is None and z3.is_true( - model.eval(cnt_list[j], model_completion=True) - ): - results[j] = 1 - elif res == z3.unsat: - results[i] = 0 - else: - results[i] = 2 - solver.pop() # Restore the state - - -def disjunctive_check_incremental_cached( - solver: z3.Solver, - cnt_list: List[z3.ExprRef], - results: List[int], - check_list: List[int], -): - """Recursively solve a disjunction of pending constraints, caching satisfiable members.""" - f = z3.BoolVal(False) - conditions = [cnt_list[i] for i in check_list if results[i] is None] - - if len(conditions) == 0: - return - - f = z3.Or(conditions) - - if z3.is_false(f): - return - - solver.push() - solver.add(f) - res = solver.check() - if res == z3.unsat: - for i in check_list: - results[i] = 0 - solver.pop() - elif res == z3.sat: - m = solver.model() - solver.pop() - new_check_list = [] - for i in check_list: - if results[i] is None and z3.is_true( - m.eval(cnt_list[i], model_completion=True) - ): - results[i] = 1 - elif results[i] is None: - new_check_list.append(i) - disjunctive_check_incremental_cached(solver, cnt_list, results, new_check_list) - - -def conjunctive_check_incremental( - precond: z3.ExprRef, cnt_list: List[z3.ExprRef], alogorithm: int = 0 -) -> List: - """ - Perform a conjunctive satisfiability check on a list of constraints under a given precondition. - This function checks whether the conjunction of a set of constraints (`cnt_list`) is satisfiable - under a given precondition (`precond`). It uses a Z3 solver to perform the satisfiability checks - and supports different algorithms for handling unsatisfiable cores. - Args: - precond (z3.ExprRef): The precondition to be added to the solver. - cnt_list (List[z3.ExprRef]): A list of Z3 expressions representing the constraints to be checked. - alogorithm (int): The algorithm to use for handling unsatisfiable cores. - Options are: - - 0: Unary check with caching. - - 1: Incremental unary check with caching. - - 2: Incremental disjunctive check with caching. - Returns: - List: A list of results where each index corresponds to the satisfiability of the respective - constraint in `cnt_list`. A value of `1` indicates satisfiable, `0` indicates unsatisfiable, - and `2` indicates unknown. - Notes: - - Unsatisfiable cores are handled by moving them to a waiting list for further processing - based on the selected algorithm. - """ - results = [None] * len(cnt_list) - solver = z3.Solver() - waiting_list_idx = [] - queue = [list(range(len(cnt_list)))] - i = 0 - while queue: - i += 1 - current_subset = queue.pop(0) - solver.push() - for idx in current_subset: - solver.assert_and_track(cnt_list[idx], str(idx)) - solver_result = solver.check() - if ( - solver_result == z3.sat - ): # no conflicts within the predicates, no need to split - while True: - solver.add(precond) - solver_result = solver.check() - solver.pop() - - if solver_result == z3.sat: - # All constraints are satisfiable - for idx in current_subset: - results[idx] = 1 - break - elif solver_result == z3.unsat: - # Move unsat core to waiting list - unsat_core = solver.unsat_core() - for idx in unsat_core: - current_subset.remove(int(str(idx))) - waiting_list_idx.append(int(str(idx))) - if len(current_subset) == 0: - break - solver.push() - for idx in current_subset: - solver.assert_and_track(cnt_list[idx], str(idx)) - - elif ( - solver_result == z3.unsat - ): # conflicts within the predicates, need to split - solver.pop() - unsat_core_indices = {int(c.decl().name()) for c in solver.unsat_core()} - unsat_set_indices = list(unsat_core_indices) - sat_set_indices = [i for i in current_subset if i not in unsat_core_indices] - if len(unsat_set_indices) == 1: - waiting_list_idx.append(unsat_set_indices[0]) - if sat_set_indices: - queue.append(sat_set_indices) - else: - subsets = [ - [unsat_set_indices[i]] for i in range(len(unsat_set_indices)) - ] - for i, sat_item in enumerate(sat_set_indices): - subsets[i % len(unsat_set_indices)].append(sat_item) - queue.extend(subsets) - - solver.add(precond) - if alogorithm == 0: - unary_check_cached(precond, cnt_list, results, waiting_list_idx) - elif alogorithm == 1: - unary_check_incremental_cached(solver, cnt_list, results, waiting_list_idx) - elif alogorithm == 2: - disjunctive_check_incremental_cached( - solver, cnt_list, results, waiting_list_idx - ) - else: - raise ValueError("Invalid algorithm choice. Choose 0, 1, or 2.") - - return results - - -def conjunctive_check( - precond: z3.ExprRef, cnt_list: List[z3.ExprRef], alogorithm: int = 0 -) -> List: - """ - Perform a conjunctive satisfiability check on a list of constraints under a given precondition. - This function checks whether the conjunction of a set of constraints (`cnt_list`) is satisfiable - under a given precondition (`precond`). It uses a Z3 solver to perform the satisfiability checks - and supports different algorithms for handling unsatisfiable cores. - Args: - precond (z3.ExprRef): The precondition to be added to the solver. - cnt_list (List[z3.ExprRef]): A list of Z3 expressions representing the constraints to be checked. - alogorithm (int): The algorithm to use for handling unsatisfiable cores. - Options are: - - 0: Unary check with caching. - - 1: Incremental unary check with caching. - - 2: Incremental disjunctive check with caching. - Returns: - List: A list of results where each index corresponds to the satisfiability of the respective - constraint in `cnt_list`. A value of `1` indicates satisfiable, `0` indicates unsatisfiable, - and `2` indicates unknown. - Notes: - - Unsatisfiable cores are handled by moving them to a waiting list for further processing - based on the selected algorithm. - """ - results = [None] * len(cnt_list) - waiting_list_idx = [] - queue = [list(range(len(cnt_list)))] - i = 0 - while queue: - i += 1 - current_subset = queue.pop(0) - solver_split = z3.Solver() - for idx in current_subset: - solver_split.assert_and_track(cnt_list[idx], str(idx)) - solver_result = solver_split.check() - if ( - solver_result == z3.sat - ): # no conflicts within the predicates, no need to split - while True: - solver_check = z3.Solver() - solver_check.add(precond) - for idx in current_subset: - solver_check.assert_and_track(cnt_list[idx], str(idx)) - solver_result = solver_check.check() - - if solver_result == z3.sat: - # All constraints are satisfiable - for idx in current_subset: - results[idx] = 1 - break - elif solver_result == z3.unsat: - # Move unsat core to waiting list - unsat_core = solver_check.unsat_core() - for idx in unsat_core: - current_subset.remove(int(str(idx))) - waiting_list_idx.append(int(str(idx))) - if len(current_subset) == 0: - break - - elif ( - solver_result == z3.unsat - ): # conflicts within the predicates, need to split - unsat_core_indices = { - int(c.decl().name()) for c in solver_split.unsat_core() - } - unsat_set_indices = list(unsat_core_indices) - sat_set_indices = [i for i in current_subset if i not in unsat_core_indices] - if len(unsat_set_indices) == 1: - waiting_list_idx.append(unsat_set_indices[0]) - if sat_set_indices: - queue.append(sat_set_indices) - else: - subsets = [ - [unsat_set_indices[i]] for i in range(len(unsat_set_indices)) - ] - for i, sat_item in enumerate(sat_set_indices): - subsets[i % len(unsat_set_indices)].append(sat_item) - queue.extend(subsets) - - solver_fallback = z3.Solver() - solver_fallback.add(precond) - if alogorithm == 0: - unary_check_cached(precond, cnt_list, results, waiting_list_idx) - elif alogorithm == 1: - unary_check_incremental_cached( - solver_fallback, cnt_list, results, waiting_list_idx - ) - elif alogorithm == 2: - disjunctive_check_incremental_cached( - solver_fallback, cnt_list, results, waiting_list_idx - ) - else: - raise ValueError("Invalid algorithm choice. Choose 0, 1, or 2.") - - return results diff --git a/aria/monabs/cores/con_check_pysmt.py b/aria/monabs/cores/con_check_pysmt.py deleted file mode 100644 index 8776b4bc..00000000 --- a/aria/monabs/cores/con_check_pysmt.py +++ /dev/null @@ -1,302 +0,0 @@ -""" -PySMT versions of conjunctive/disjunctive satisfiability helpers with optional caching. - -Result encoding (shared across helpers): -- 1: constraint satisfiable under the given precondition -- 0: unsatisfiable -- 2: unknown (e.g., solver returned unknown) -""" - -from typing import List, Optional - -from pysmt.exceptions import SolverReturnedUnknownResultError -from pysmt.shortcuts import Or, Solver - - -def _check(solver: Solver) -> str: - """Normalize solver status to sat/unsat/unknown.""" - try: - res = solver.solve() - except SolverReturnedUnknownResultError: - return "unknown" - - if res is True: - return "sat" - if res is False: - return "unsat" - return "unknown" - - -def _core_indices(core) -> List[int]: - """Extract integer indices from a PySMT named unsat core.""" - indices: List[int] = [] - # core may be a dict(name -> formula) or an iterable of names - items = core.keys() if hasattr(core, "keys") else core - for item in items: - name = str(item) - if name.startswith("idx_"): - try: - indices.append(int(name[4:])) - continue - except ValueError: - pass - try: - indices.append(int(name)) - except ValueError: - # Skip entries we cannot map back to indices - continue - return indices - - -def unary_check_cached( - precond, cnt_list: List, results: List[Optional[int]], check_list: List[int] -): - """Non-incremental check with model-based caching over a subset of constraints.""" - for i in check_list: - if results[i] is not None: - continue - solver = Solver() - solver.add_assertion(precond) - solver.add_assertion(cnt_list[i]) - res = _check(solver) - if res == "sat": - model = solver.get_model() - results[i] = 1 - for j in check_list: - if results[j] is None and model.get_value(cnt_list[j]).is_true(): - results[j] = 1 - elif res == "unsat": - results[i] = 0 - else: - results[i] = 2 - - -def unary_check_incremental_cached( - solver: Solver, cnt_list: List, results: List[Optional[int]], check_list: List[int] -): - """Incremental version of unary_check_cached using a shared solver with push/pop.""" - for i in check_list: - if results[i] is not None: - continue - solver.push() - solver.add_assertion(cnt_list[i]) - res = _check(solver) - if res == "sat": - model = solver.get_model() - results[i] = 1 - for j in check_list: - if results[j] is None and model.get_value(cnt_list[j]).is_true(): - results[j] = 1 - solver.pop() - elif res == "unsat": - core = solver.get_named_unsat_core() - solver.pop() - results[i] = 0 - # core content not used here, but keep symmetry - else: - results[i] = 2 - solver.pop() - - -def disjunctive_check_incremental_cached( - solver: Solver, cnt_list: List, results: List[Optional[int]], check_list: List[int] -): - """Recursively solve a disjunction of pending constraints, caching satisfiable members.""" - conditions = [cnt_list[i] for i in check_list if results[i] is None] - if len(conditions) == 0: - return - - f = Or(conditions) - if f.is_false(): - return - - solver.push() - solver.add_assertion(f) - res = _check(solver) - if res == "unsat": - for i in check_list: - results[i] = 0 - solver.pop() - elif res == "sat": - m = solver.get_model() - solver.pop() - new_check_list = [] - for i in check_list: - if results[i] is None and m.get_value(cnt_list[i]).is_true(): - results[i] = 1 - elif results[i] is None: - new_check_list.append(i) - disjunctive_check_incremental_cached(solver, cnt_list, results, new_check_list) - else: - solver.pop() - for i in check_list: - if results[i] is None: - results[i] = 2 - - -def conjunctive_check_incremental( - precond, cnt_list: List, alogorithm: int = 0 -) -> List[int]: - """ - Perform a conjunctive satisfiability check on a list of constraints under a given precondition. - Mirrors the Z3-based version but uses PySMT solvers and APIs. - """ - results: List[Optional[int]] = [None] * len(cnt_list) - solver = Solver(name="z3", unsat_cores_mode="named") # need unsat cores - waiting_list_idx: List[int] = [] - queue: List[List[int]] = [list(range(len(cnt_list)))] - - while queue: - current_subset = queue.pop(0) - solver.push() - for idx in current_subset: - solver.add_assertion(cnt_list[idx], named=f"idx_{idx}") - solver_result = _check(solver) - if ( - solver_result == "sat" - ): # no conflicts within the predicates, no need to split - while True: - solver.add_assertion(precond) - solver_result = _check(solver) - if solver_result == "sat": - solver.pop() - for idx in current_subset: - results[idx] = 1 - break - elif solver_result == "unsat": - unsat_core = _core_indices(solver.get_named_unsat_core()) - solver.pop() - for idx in unsat_core: - if idx in current_subset: - current_subset.remove(idx) - if idx not in waiting_list_idx: - waiting_list_idx.append(idx) - if len(current_subset) == 0: - break - solver.push() - for idx in current_subset: - solver.add_assertion(cnt_list[idx], named=f"idx_{idx}") - else: # unknown - solver.pop() - for idx in current_subset: - if results[idx] is None: - results[idx] = 2 - break - - elif solver_result == "unsat": # conflicts within the predicates, need to split - core_indices = set(_core_indices(solver.get_named_unsat_core())) - solver.pop() - unsat_set_indices = list(core_indices) - sat_set_indices = [i for i in current_subset if i not in core_indices] - if len(unsat_set_indices) == 1: - waiting_list_idx.append(unsat_set_indices[0]) - if sat_set_indices: - queue.append(sat_set_indices) - else: - subsets = [ - [unsat_set_indices[i]] for i in range(len(unsat_set_indices)) - ] - for i, sat_item in enumerate(sat_set_indices): - subsets[i % len(unsat_set_indices)].append(sat_item) - queue.extend(subsets) - else: - solver.pop() - for idx in current_subset: - if results[idx] is None: - results[idx] = 2 - - solver.add_assertion(precond) - if alogorithm == 0: - unary_check_cached(precond, cnt_list, results, waiting_list_idx) - elif alogorithm == 1: - unary_check_incremental_cached(solver, cnt_list, results, waiting_list_idx) - elif alogorithm == 2: - disjunctive_check_incremental_cached( - solver, cnt_list, results, waiting_list_idx - ) - else: - raise ValueError("Invalid algorithm choice. Choose 0, 1, or 2.") - - return results # type: ignore[return-value] - - -def conjunctive_check(precond, cnt_list: List, alogorithm: int = 0) -> List[int]: - """ - Non-incremental version of conjunctive satisfiability check using PySMT. - """ - results: List[Optional[int]] = [None] * len(cnt_list) - waiting_list_idx: List[int] = [] - queue: List[List[int]] = [list(range(len(cnt_list)))] - i = 0 - while queue: - i += 1 - current_subset = queue.pop(0) - solver_split = Solver(name="z3", unsat_cores_mode="named") - for idx in current_subset: - solver_split.add_assertion(cnt_list[idx], named=f"idx_{idx}") - solver_result = _check(solver_split) - if ( - solver_result == "sat" - ): # no conflicts within the predicates, no need to split - while True: - solver_check = Solver(name="z3", unsat_cores_mode="named") - solver_check.add_assertion(precond) - for idx in current_subset: - solver_check.add_assertion(cnt_list[idx], named=f"idx_{idx}") - solver_result = _check(solver_check) - - if solver_result == "sat": - for idx in current_subset: - results[idx] = 1 - break - elif solver_result == "unsat": - unsat_core = _core_indices(solver_check.get_named_unsat_core()) - for idx in unsat_core: - if idx in current_subset: - current_subset.remove(idx) - if idx not in waiting_list_idx: - waiting_list_idx.append(idx) - if len(current_subset) == 0: - break - else: - for idx in current_subset: - if results[idx] is None: - results[idx] = 2 - break - - elif solver_result == "unsat": # conflicts within the predicates, need to split - unsat_core_indices = set(_core_indices(solver_split.get_named_unsat_core())) - unsat_set_indices = list(unsat_core_indices) - sat_set_indices = [i for i in current_subset if i not in unsat_core_indices] - if len(unsat_set_indices) == 1: - waiting_list_idx.append(unsat_set_indices[0]) - if sat_set_indices: - queue.append(sat_set_indices) - else: - subsets = [ - [unsat_set_indices[i]] for i in range(len(unsat_set_indices)) - ] - for i, sat_item in enumerate(sat_set_indices): - subsets[i % len(unsat_set_indices)].append(sat_item) - queue.extend(subsets) - else: - for idx in current_subset: - if results[idx] is None: - results[idx] = 2 - - solver_fallback = Solver(name="z3", unsat_cores_mode="named") - solver_fallback.add_assertion(precond) - if alogorithm == 0: - unary_check_cached(precond, cnt_list, results, waiting_list_idx) - elif alogorithm == 1: - unary_check_incremental_cached( - solver_fallback, cnt_list, results, waiting_list_idx - ) - elif alogorithm == 2: - disjunctive_check_incremental_cached( - solver_fallback, cnt_list, results, waiting_list_idx - ) - else: - raise ValueError("Invalid algorithm choice. Choose 0, 1, or 2.") - - return results # type: ignore[return-value] diff --git a/aria/monabs/cores/dis_check.py b/aria/monabs/cores/dis_check.py deleted file mode 100644 index 4012f6b1..00000000 --- a/aria/monabs/cores/dis_check.py +++ /dev/null @@ -1,116 +0,0 @@ -""" -Disjunctive over-approximation helpers. - -Given a shared precondition and a list of constraints, these routines try to -classify each constraint as: -- 1: satisfiable under the precondition -- 0: unsatisfiable under the precondition -- 2: unknown (solver returned unknown or not yet decided) - -They explore the disjunction of all still-unknown constraints, using caching -and (optionally) incremental push/pop to reduce solver calls. -""" - -from typing import List -import z3 - - -def compact_check_cached( - precond: z3.ExprRef, cnt_list: List[z3.ExprRef], res_label: List -): - """Recursive disjunctive check (fresh solver each recursion), with model caching.""" - f = z3.BoolVal(False) - - conditions = [] - for i in range(len(res_label)): - if res_label[i] == 2: - conditions.append(cnt_list[i]) - - if len(conditions) == 0: - return - - f = z3.Or(conditions) - - if z3.is_false(f): - return - - solver = z3.Solver() - g = z3.And(precond, f) - solver.add(g) - res = solver.check() - if res == z3.unsat: - for i in range(len(res_label)): - if res_label[i] == 2: - res_label[i] = 0 - elif res == z3.sat: - m = solver.model() - for i in range(len(res_label)): - if res_label[i] == 2 and z3.is_true( - m.eval(cnt_list[i], model_completion=True) - ): - res_label[i] = 1 - else: - return - compact_check_cached(precond, cnt_list, res_label) - - -def disjunctive_check_cached( - precond: z3.ExprRef, cnt_list: List[z3.ExprRef] -) -> List[int]: - """Entry point for cached disjunctive checking (non-incremental solver usage).""" - res = [2] * len(cnt_list) # 0 means unsat, 1 means sat, 2 means "unknown" - compact_check_cached(precond, cnt_list, res) - return res - - -def compact_check_incremental_cached( - solver: z3.Solver, - precond: z3.ExprRef, - cnt_list: List[z3.ExprRef], - res_label: List[int], -): - """Recursive disjunctive check using a shared solver with push/pop for efficiency.""" - f = z3.BoolVal(False) - - conditions = [] - for i, label in enumerate(res_label): - if label == 2: - conditions.append(cnt_list[i]) - - if len(conditions) == 0: - return - - f = z3.Or(conditions) - - if z3.is_false(f): - return - - solver.push() - solver.add(f) - res = solver.check() - if res == z3.unsat: - for i in range(len(res_label)): - if res_label[i] == 2: - res_label[i] = 0 - elif res == z3.sat: - m = solver.model() - for i in range(len(res_label)): - if res_label[i] == 2 and z3.is_true( - m.eval(cnt_list[i], model_completion=True) - ): - res_label[i] = 1 - else: - return - solver.pop() - compact_check_incremental_cached(solver, precond, cnt_list, res_label) - - -def disjunctive_check_incremental_cached( - precond: z3.ExprRef, cnt_list: List[z3.ExprRef] -) -> List[int]: - """Entry point for cached disjunctive checking with a shared incremental solver.""" - results = [2] * len(cnt_list) - solver = z3.Solver() - solver.add(precond) - compact_check_incremental_cached(solver, precond, cnt_list, results) - return results diff --git a/aria/monabs/cores/dis_check_pysmt.py b/aria/monabs/cores/dis_check_pysmt.py index 47b393b7..6b0c72a4 100644 --- a/aria/monabs/cores/dis_check_pysmt.py +++ b/aria/monabs/cores/dis_check_pysmt.py @@ -7,28 +7,38 @@ - 2: unknown (solver returned unknown or not yet decided) """ -from typing import List +from typing import Tuple, List from pysmt.exceptions import SolverReturnedUnknownResultError from pysmt.shortcuts import And, Or, Solver -def _check(solver: Solver) -> str: - """Normalize solver status to sat/unsat/unknown.""" +# ── Internal solver helper ─────────────────────────────────────────────────── + +def _check(solver: Solver, solver_calls: list[int]) -> str: + """ + Normalize solver status to sat/unsat/unknown/timeout. + """ + solver_calls[0] += 1 try: res = solver.solve() except SolverReturnedUnknownResultError: - return "unknown" + return "timeout" if res is True: return "sat" if res is False: return "unsat" + return "unknown" -def compact_check_cached(precond, cnt_list: List, res_label: List[int]): - """Recursive disjunctive check (fresh solver each recursion), with model caching.""" +# ── Algorithm 1: OA ─────────────────────────────────────────────────────────── + +def _compact_check_cached(precond, cnt_list: List, res_label: List[int], solver_calls: list[int], timeout_ms: int = 0): + """ + Recursive disjunctive check (fresh solver each recursion), with model caching. + """ conditions = [cnt_list[i] for i, lbl in enumerate(res_label) if lbl == 2] if len(conditions) == 0: @@ -39,34 +49,39 @@ def compact_check_cached(precond, cnt_list: List, res_label: List[int]): if f.is_false(): return - solver = Solver() - solver.add_assertion(And(precond, f)) - res = _check(solver) - if res == "unsat": - for i, lbl in enumerate(res_label): - if lbl == 2: - res_label[i] = 0 - elif res == "sat": - m = solver.get_model() - for i, lbl in enumerate(res_label): - if lbl == 2 and m.get_value(cnt_list[i]).is_true(): - res_label[i] = 1 - else: - return - compact_check_cached(precond, cnt_list, res_label) - - -def disjunctive_check_cached(precond, cnt_list: List) -> List[int]: - """Entry point for cached disjunctive checking (non-incremental solver usage).""" + with Solver(name="z3", solver_options={"timeout": timeout_ms}) as solver: + solver.add_assertion(And(precond, f)) + status = _check(solver, solver_calls) + if status == "unsat": + for i, lbl in enumerate(res_label): + if lbl == 2: + res_label[i] = 0 + elif status == "sat": + m = solver.get_model() + for i, lbl in enumerate(res_label): + if lbl == 2 and m.get_value(cnt_list[i]).is_true(): + res_label[i] = 1 + else: # timeout or unknown + return + _compact_check_cached(precond, cnt_list, res_label, solver_calls, timeout_ms) + + +def disjunctive_check_cached(precond, cnt_list: List, timeout_ms: int = 0) -> Tuple[List[int], int]: + """ + Entry point for cached disjunctive checking (non-incremental solver usage). + """ res = [2] * len(cnt_list) # 0 means unsat, 1 means sat, 2 means "unknown" - compact_check_cached(precond, cnt_list, res) - return res + solver_calls = [0] + _compact_check_cached(precond, cnt_list, res, solver_calls, timeout_ms) + return res, solver_calls[0] + +# ── Algorithm 2: OA-Inc ─────────────────────────────────────────────────────── -def compact_check_incremental_cached( - solver: Solver, precond, cnt_list: List, res_label: List[int] -): - """Recursive disjunctive check using a shared solver with push/pop for efficiency.""" +def _compact_check_incremental_cached(solver: Solver, precond, cnt_list: List, res_label: List[int], solver_calls: list[int]): + """ + Recursive disjunctive check using a shared solver with push/pop for efficiency. + """ conditions = [cnt_list[i] for i, lbl in enumerate(res_label) if lbl == 2] if len(conditions) == 0: @@ -79,27 +94,30 @@ def compact_check_incremental_cached( solver.push() solver.add_assertion(f) - res = _check(solver) - if res == "unsat": + status = _check(solver, solver_calls) + if status == "unsat": for i, lbl in enumerate(res_label): if lbl == 2: res_label[i] = 0 - elif res == "sat": + elif status == "sat": m = solver.get_model() for i, lbl in enumerate(res_label): if lbl == 2 and m.get_value(cnt_list[i]).is_true(): res_label[i] = 1 - else: + else: # timeout or unknown solver.pop() return solver.pop() - compact_check_incremental_cached(solver, precond, cnt_list, res_label) + _compact_check_incremental_cached(solver, precond, cnt_list, res_label, solver_calls) -def disjunctive_check_incremental_cached(precond, cnt_list: List) -> List[int]: - """Entry point for cached disjunctive checking with a shared incremental solver.""" +def disjunctive_check_incremental_cached(precond, cnt_list: List, timeout_ms: int = 0) -> Tuple[List[int], int]: + """ + Entry point for cached disjunctive checking with a shared incremental solver. + """ results = [2] * len(cnt_list) - solver = Solver() - solver.add_assertion(precond) - compact_check_incremental_cached(solver, precond, cnt_list, results) - return results + solver_calls = [0] + with Solver(name="z3", solver_options={"timeout": timeout_ms}) as solver: + solver.add_assertion(precond) + _compact_check_incremental_cached(solver, precond, cnt_list, results, solver_calls) + return results, solver_calls[0] diff --git a/aria/monabs/cores/new_check_pysmt.py b/aria/monabs/cores/new_check_pysmt.py new file mode 100644 index 00000000..790a259a --- /dev/null +++ b/aria/monabs/cores/new_check_pysmt.py @@ -0,0 +1,134 @@ +from typing import List, Optional, Set, Tuple + +from pysmt.exceptions import SolverReturnedUnknownResultError +from pysmt.shortcuts import And, Not, Or, Solver, BOOL, EqualsOrIff + + +# ── Internal solver helper ─────────────────────────────────────────────────── + +def _check(solver: Solver, solver_calls: list[int]) -> str: + """ + Normalize solver status to sat/unsat/unknown/timeout. + """ + solver_calls[0] += 1 + try: + res = solver.solve() + except SolverReturnedUnknownResultError: + return "timeout" + + if res is True: + return "sat" + if res is False: + return "unsat" + + return "unknown" + +# ── Algorithm: CoreLitFilter ─────────────────────────────────────────────── + +def _get_top_level_literals(formula) -> List: + """ + Extract top-level literals from a conjunction, returned as a list. + For (a ∧ b ∧ c), returns [a, b, c]. + For a single atom/literal, returns [formula]. + Does NOT recurse into disjunctions, ite, or implications — only And nodes. + """ + lits: List = [] + try: + if formula.is_and(): + for arg in formula.args(): + lits.extend(_get_top_level_literals(arg)) + else: + lits.append(formula) + except Exception: + pass + return lits + + +def _is_blocked(formula, forbidden_ids: Set[int]) -> bool: + """ + Return True if `formula` contains a top-level literal whose node_id + is in `forbidden_ids`. + """ + for lit in _get_top_level_literals(formula): + try: + if id(lit) in forbidden_ids or lit.node_id() in forbidden_ids: + return True + except Exception: + pass + return False + + +def core_lit_filter(precond, cnt_list: List, timeout_ms: int = 0, solver_calls: Optional[list] = None) -> Tuple[List[int], int]: + """ + Accumulate forbidden literals from UNSAT results and use them to pre-screen future predicates at zero solver cost. + """ + if solver_calls is None: + solver_calls = [0] + + n = len(cnt_list) + results = [2] * n + + # forbidden_ids: set of node_id() values of literals ℓ where φ ⊨ ¬ℓ + forbidden_ids: Set[int] = set() + FORBIDDEN_BUDGET = 64 # max literals to verify, limits extra solver calls + + with Solver(name="z3", solver_options={"timeout": timeout_ms}) as solver: + solver.add_assertion(precond) + + for i, cnt in enumerate(cnt_list): + if results[i] != 2: + continue + + # Free pre-screening (zero solver calls), the solver confirmed φ ∧ ℓ is UNSAT in a clean φ-only scope. + if forbidden_ids and _is_blocked(cnt, forbidden_ids): + results[i] = 0 + continue + + # Standard incremental check + solver.push() + solver.add_assertion(cnt) + status = _check(solver, solver_calls) + + if status == "timeout" or status == "unknown": + solver.pop() + return results, solver_calls[0] + + if status == "sat": + results[i] = 1 + model = solver.get_model() + for j in range(i + 1, n): + if results[j] == 2: + try: + if model.get_value(cnt_list[j]).is_true(): + results[j] = 1 + except Exception: + pass + solver.pop() + + else: # unsat + results[i] = 0 + # Collect candidate literals from p_i for forbidden verification. + candidates = [] + if len(forbidden_ids) < FORBIDDEN_BUDGET: + for lit in _get_top_level_literals(cnt): + try: + nid = lit.node_id() + except Exception: + nid = id(lit) + if nid not in forbidden_ids: + candidates.append((nid, lit)) + + solver.pop() + + # Now verify candidates under φ alone, UNSAT here genuinely means φ ⊨ ¬ℓ. + for nid, lit in candidates: + if len(forbidden_ids) >= FORBIDDEN_BUDGET: + break + solver.push() + solver.add_assertion(lit) + lit_status = _check(solver, solver_calls) + solver.pop() + if lit_status == "unsat": + forbidden_ids.add(nid) + + return results, solver_calls[0] \ No newline at end of file diff --git a/aria/monabs/cores/unary_check.py b/aria/monabs/cores/unary_check.py deleted file mode 100644 index 53b65317..00000000 --- a/aria/monabs/cores/unary_check.py +++ /dev/null @@ -1,119 +0,0 @@ -""" -Unary satisfiability checks for lists of constraints. - -The helpers here answer, for each constraint under a shared precondition: -- 1: constraint is satisfiable -- 0: constraint is unsatisfiable -- 2: solver returned unknown - -Variants cover basic, incremental (push/pop), and cache-aware modes that -re-use models to mark additional constraints as satisfiable when possible. -""" - -from typing import List -import z3 - - -def unary_check(precond: z3.ExprRef, cnt_list: List[z3.ExprRef]) -> List: - """Check each constraint independently under the precondition (fresh solver per check).""" - results = [None] * len(cnt_list) - - for i, cnt in enumerate(cnt_list): - solver = z3.Solver() - solver.add(precond) # Add the precondition - solver.add(cnt) # Add the current constraint - res = solver.check() - if res == z3.sat: - results[i] = 1 - elif res == z3.unsat: - results[i] = 0 - else: - results[i] = 2 - - return results - - -def unary_check_incremental(precond: z3.ExprRef, cnt_list: List[z3.ExprRef]) -> List: - """Check each constraint with a shared solver using push/pop for efficiency.""" - results = [None] * len(cnt_list) - solver = z3.Solver() - - solver.add(precond) # Add the precondition - for i, cnt in enumerate(cnt_list): - solver.push() # Save the current state - - solver.add(cnt) # Add the current constraint - res = solver.check() - if res == z3.sat: - results[i] = 1 - elif res == z3.unsat: - results[i] = 0 - else: - results[i] = 2 - - solver.pop() # Restore the state - - return results - - -def unary_check_cached(precond: z3.ExprRef, cnt_list: List[z3.ExprRef]) -> List: - """Reuse satisfying models to mark other constraints true when implied by the model.""" - results = [None] * len(cnt_list) - - for i, cnt in enumerate(cnt_list): - if results[i] is not None: - continue - - solver = z3.Solver() - solver.add(precond) # Add the precondition - solver.add(cnt) # Add the current constraint - res = solver.check() - if res == z3.sat: - model = solver.model() - results[i] = 1 - for j, other_cnt in enumerate(cnt_list): - if results[j] is None and z3.is_true( - model.eval(other_cnt, model_completion=True) - ): - results[j] = 1 - elif res == z3.unsat: - results[i] = 0 - else: - results[i] = 2 - - return results - - -def unary_check_incremental_cached( - precond: z3.ExprRef, cnt_list: List[z3.ExprRef] -) -> List: - """Incremental + caching: share solver state and propagate model truths across constraints.""" - results = [None] * len(cnt_list) - solver = z3.Solver() - - solver.add(precond) # Add the precondition - - for i, cnt in enumerate(cnt_list): - if results[i] is not None: - continue - - solver.push() # Save the current state - - solver.add(cnt) # Add the current constraint - res = solver.check() - if res == z3.sat: - model = solver.model() - results[i] = 1 - for j, other_cnt in enumerate(cnt_list): - if results[j] is None and z3.is_true( - model.eval(other_cnt, model_completion=True) - ): - results[j] = 1 - elif res == z3.unsat: - results[i] = 0 - else: - results[i] = 2 - - solver.pop() # Restore the state - - return results diff --git a/aria/monabs/cores/unary_check_pysmt.py b/aria/monabs/cores/unary_check_pysmt.py index a6dc4151..f15d55ba 100644 --- a/aria/monabs/cores/unary_check_pysmt.py +++ b/aria/monabs/cores/unary_check_pysmt.py @@ -1,127 +1,160 @@ """ PySMT counterparts of the Z3-based unary satisfiability helpers. -Result encoding (shared across helpers): -- 1: constraint satisfiable under the given precondition -- 0: unsatisfiable -- 2: unknown (e.g., solver returned unknown) +Result encoding: +- 1: satisfiable under the precondition +- 0: unsatisfiable under the precondition +- 2: unknown (solver returned unknown or not yet decided) """ -from typing import List, Optional +from typing import Tuple, List, Optional from pysmt.exceptions import SolverReturnedUnknownResultError from pysmt.shortcuts import Solver -def _check(solver: Solver) -> str: - """ - Run the solver and normalize the outcome. +# ── Internal solver helper ─────────────────────────────────────────────────── - Returns: - "sat" | "unsat" | "unknown" +def _check(solver: Solver, solver_calls: list[int]) -> str: + """ + Normalize solver status to sat/unsat/unknown/timeout. """ + solver_calls[0] += 1 try: res = solver.solve() except SolverReturnedUnknownResultError: - return "unknown" + return "timeout" if res is True: return "sat" if res is False: return "unsat" + return "unknown" -def unary_check(precond, cnt_list: List) -> List[int]: - """Check each constraint independently under the precondition (fresh solver per check).""" - results: List[Optional[int]] = [None] * len(cnt_list) +# ── Algorithm 1: LS ─────────────────────────────────────────────────────────── + +def unary_check(precond, cnt_list: List, timeout_ms: int = 0, solver_calls: Optional[list[int]] = None) -> Tuple[List[int], int]: + """ + Check each constraint independently under the precondition (fresh solver per check). + """ + if solver_calls is None: + solver_calls = [0] + + results: List[int] = [] for i, cnt in enumerate(cnt_list): - solver = Solver() + with Solver(name="z3", solver_options={"timeout": timeout_ms}) as solver: + solver.add_assertion(precond) + solver.add_assertion(cnt) + + status = _check(solver, solver_calls) + if status == "sat": + results.append(1) + elif status == "unsat": + results.append(0) + else: # timeout or unknown + results.append(2) + + return results, solver_calls[0] + + +# ── Algorithm 2: LS-Inc ─────────────────────────────────────────────────────── + +def unary_check_incremental(precond, cnt_list: List, timeout_ms: int = 0, solver_calls: Optional[list[int]] = None) -> Tuple[List[int], int]: + """ + Check each constraint with a shared solver using push/pop for efficiency. + """ + if solver_calls is None: + solver_calls = [0] + + results: List[int] = [] + + with Solver(name="z3", solver_options={"timeout": timeout_ms}) as solver: solver.add_assertion(precond) - solver.add_assertion(cnt) - res = _check(solver) - if res == "sat": - results[i] = 1 - elif res == "unsat": - results[i] = 0 - else: - results[i] = 2 + for i, cnt in enumerate(cnt_list): + solver.push() + solver.add_assertion(cnt) + status = _check(solver, solver_calls) + if status == "sat": + results.append(1) + elif status == "unsat": + results.append(0) + else: # timeout or unknown + results.append(2) + solver.pop() - return results # type: ignore[return-value] + return results, solver_calls[0] -def unary_check_incremental(precond, cnt_list: List) -> List[int]: - """Check each constraint with a shared solver using push/pop for efficiency.""" - results: List[Optional[int]] = [None] * len(cnt_list) - solver = Solver() - solver.add_assertion(precond) +# ── Algorithm 3: LS-Reuse ─────────────────────────────────────────────────────── + +def unary_check_cached(precond, cnt_list: List, timeout_ms: int = 0, solver_calls: Optional[list[int]] = None) -> Tuple[List[int], int]: + """ + Reuse satisfying models to mark other constraints true when implied by the model. + """ + if solver_calls is None: + solver_calls = [0] - for i, cnt in enumerate(cnt_list): - solver.push() - solver.add_assertion(cnt) - res = _check(solver) - if res == "sat": - results[i] = 1 - elif res == "unsat": - results[i] = 0 - else: - results[i] = 2 - solver.pop() - - return results # type: ignore[return-value] - - -def unary_check_cached(precond, cnt_list: List) -> List[int]: - """Reuse satisfying models to mark other constraints true when implied by the model.""" results: List[Optional[int]] = [None] * len(cnt_list) for i, cnt in enumerate(cnt_list): if results[i] is not None: continue - solver = Solver() - solver.add_assertion(precond) - solver.add_assertion(cnt) - res = _check(solver) - if res == "sat": - model = solver.get_model() - results[i] = 1 - for j, other_cnt in enumerate(cnt_list): - if results[j] is None and model.get_value(other_cnt).is_true(): - results[j] = 1 - elif res == "unsat": - results[i] = 0 - else: - results[i] = 2 - - return results # type: ignore[return-value] - - -def unary_check_incremental_cached(precond, cnt_list: List) -> List[int]: - """Incremental + caching: share solver state and propagate model truths across constraints.""" + with Solver(name="z3", solver_options={"timeout": timeout_ms}) as solver: + solver.add_assertion(precond) + solver.add_assertion(cnt) + status = _check(solver, solver_calls) + + if status == "sat": + model = solver.get_model() + results[i] = 1 + for j, other_cnt in enumerate(cnt_list): + if results[j] is None and model.get_value(other_cnt).is_true(): + results[j] = 1 + elif status == "unsat": + results[i] = 0 + else: # timeout or unknown + results[i] = 2 + + return results, solver_calls[0] + + +# ── Algorithm 4: LS-IncReuse ─────────────────────────────────────────────────────── + +def unary_check_incremental_cached(precond, cnt_list: List, timeout_ms: int = 0, solver_calls: Optional[list[int]] = None) -> Tuple[List[int], int]: + """ + Incremental + caching: share solver state and propagate model truths across constraints. + """ + if solver_calls is None: + solver_calls = [0] + results: List[Optional[int]] = [None] * len(cnt_list) - solver = Solver() - solver.add_assertion(precond) - for i, cnt in enumerate(cnt_list): - if results[i] is not None: - continue + with Solver(name="z3", solver_options={"timeout": timeout_ms}) as solver: + solver.add_assertion(precond) - solver.push() - solver.add_assertion(cnt) - res = _check(solver) - if res == "sat": - model = solver.get_model() - results[i] = 1 - for j, other_cnt in enumerate(cnt_list): - if results[j] is None and model.get_value(other_cnt).is_true(): - results[j] = 1 - elif res == "unsat": - results[i] = 0 - else: - results[i] = 2 - solver.pop() - - return results # type: ignore[return-value] + for i, cnt in enumerate(cnt_list): + if results[i] is not None: + continue + + solver.push() + solver.add_assertion(cnt) + status = _check(solver, solver_calls) + + if status == "sat": + model = solver.get_model() + results[i] = 1 + for j, other_cnt in enumerate(cnt_list): + if results[j] is None and model.get_value(other_cnt).is_true(): + results[j] = 1 + elif status == "unsat": + results[i] = 0 + else: # timeout or unknown + results[i] = 2 + solver.pop() + + return results, solver_calls[0] diff --git a/aria/monabs/tests/test_pysmt_monabs.py b/aria/monabs/tests/test_pysmt_monabs.py index 4ae782e3..fdc8ac21 100644 --- a/aria/monabs/tests/test_pysmt_monabs.py +++ b/aria/monabs/tests/test_pysmt_monabs.py @@ -1,70 +1,195 @@ -"""Tests for PySMT-based monadic predicate abstraction checking functions.""" +#!/usr/bin/env python3 +"""Run MBP algorithms on SMT2 files and record results as JSON.""" -import pytest -from pysmt.shortcuts import And, Bool, Not, Or, Symbol -from pysmt.typing import BOOL +from __future__ import annotations -from aria.monabs.cores.unary_check_pysmt import ( +import argparse +import json +import os +import time +import multiprocessing as mp +from dataclasses import dataclass +from typing import Dict, List + +from cores.unary_check_pysmt import ( unary_check, unary_check_cached, unary_check_incremental, unary_check_incremental_cached, ) -from aria.monabs.cores.dis_check_pysmt import ( +from cores.dis_check_pysmt import ( disjunctive_check_cached, disjunctive_check_incremental_cached, ) -from aria.monabs.cores.con_check_pysmt import ( - conjunctive_check, - conjunctive_check_incremental, +from cores.new_check_pysmt import ( + core_lit_filter, ) +from utils.logger import setup_logger +from utils.parse_monabs_pysmt import parse_monabs_pysmt +from utils.utils import collect_smt2_files +import utils.config as cf -def _vars(): - """Create test variables.""" - return Symbol("x", BOOL), Symbol("y", BOOL) +@dataclass +class RunResult: + """Container for algorithm outputs and timings.""" -@pytest.mark.parametrize( - "func", - [ - unary_check, - unary_check_incremental, - unary_check_cached, - unary_check_incremental_cached, - ], -) -def test_unary_variants(func): - """Test unary check variants.""" - x, _ = _vars() - precond = x # forces x true, so Not(x) becomes unsat - cnts = [x, Not(x), Or(x, Bool(False))] - assert func(precond, cnts) == [1, 0, 1] + outputs: Dict[str, List[int]] + solver_calls: Dict[str, int] + times: Dict[str, float] + total_time: float + length: int = 0 -@pytest.mark.parametrize( - "func", - [disjunctive_check_cached, disjunctive_check_incremental_cached], -) -def test_disjunctive_variants(func): - """Test disjunctive check variants.""" - x, y = _vars() - precond = And(x, y) - cnts = [x, y, Not(x)] - # Under precond, x and y are satisfiable; Not(x) is contradictory. - assert func(precond, cnts) == [1, 1, 0] - - -@pytest.mark.parametrize("algo", [0, 1, 2]) -@pytest.mark.parametrize( - "func", - [conjunctive_check, conjunctive_check_incremental], -) -def test_conjunctive_variants(func, algo): - """Test conjunctive check variants with different algorithms.""" - x, y = _vars() - precond = And(x, y) - cnts = [x, y, Not(x)] - # Expect the conflicting constraint to be marked unsat. - res = func(precond, cnts, alogorithm=algo) - assert res == [1, 1, 0] +def _run_algorithms(precond, constraints, timeout_ms) -> RunResult: + outputs: Dict[str, List[int]] = {} + solver_calls: Dict[str, int] = {} + times: Dict[str, float] = {} + + start_total = time.perf_counter() + + t0 = time.perf_counter() + outputs["LS-Naive"], solver_calls["LS-Naive"] = unary_check(precond, constraints, timeout_ms) + times["LS-Naive"] = time.perf_counter() - t0 + + t0 = time.perf_counter() + outputs["LS-Inc"], solver_calls["LS-Inc"] = unary_check_incremental(precond, constraints, timeout_ms) + times["LS-Inc"] = time.perf_counter() - t0 + + t0 = time.perf_counter() + outputs["LS-Reuse"], solver_calls["LS-Reuse"] = unary_check_cached(precond, constraints, timeout_ms) + times["LS-Reuse"] = time.perf_counter() - t0 + + t0 = time.perf_counter() + outputs["LS-IncReuse"], solver_calls["LS-IncReuse"] = unary_check_incremental_cached(precond, constraints, timeout_ms) + times["LS-IncReuse"] = time.perf_counter() - t0 + + t0 = time.perf_counter() + outputs["OA"], solver_calls["OA"] = disjunctive_check_cached(precond, constraints, timeout_ms) + times["OA"] = time.perf_counter() - t0 + + t0 = time.perf_counter() + outputs["OA-Inc"], solver_calls["OA-Inc"] = disjunctive_check_incremental_cached(precond, constraints, timeout_ms) + times["OA-Inc"] = time.perf_counter() - t0 + + t0 = time.perf_counter() + outputs["CORE-LIT-FILTER"], solver_calls["CORE-LIT-FILTER"] = core_lit_filter(precond, constraints, timeout_ms) + times["CORE-LIT-FILTER"] = time.perf_counter() - t0 + + total_time = time.perf_counter() - start_total + + return RunResult(outputs=outputs, solver_calls=solver_calls, times=times, total_time=total_time, length=len(constraints)) + + +def _all_equal(values: List[List[int]]) -> bool: + if not values: + return True + first = values[0] + return all(v == first for v in values[1:]) + + +def _process_single_file(filepath: str, timeout_ms: float) -> Dict: + precond, constraints = parse_monabs_pysmt(filepath) + + if len(constraints) < cf.MIN_LENGTH: + return { + "id": filepath, + "status": "invalid", + "length": 0, + "sat_ratio": 0.0, + "total_execution_time": 0.0, + "algo_execution_time": {}, + "results": {}, + "solver_calls": {}, + } + + run = _run_algorithms(precond, constraints, timeout_ms) + + if any(2 in result for result in run.outputs.values()): # 2 stands for unknown, aka timeout + status = "timeout" + sat_ratio = -1 + else: + status = "valid" if _all_equal(list(run.outputs.values())) else "error" + if run.length > 0: + sat_ratio = sum(1 for v in run.outputs["LS-Naive"] if v == 1) / run.length + else: + sat_ratio = 0.0 + + return { + "id": filepath, + "status": status, + "length": run.length, + "sat_ratio": sat_ratio, + "total_execution_time": run.total_time, + "algo_execution_time": run.times, + "results": run.outputs, + "solver_calls": run.solver_calls, + } + + +def process_file(args) -> Dict: + filepath, timeout_ms = args + print("[Debug] Processing file:", filepath) + return _process_single_file(filepath, timeout_ms) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Run MBP algorithms on SMT2 files and record JSON output." + ) + parser.add_argument( + "-i", + "--input_directory", + required=True, + help="Path to the input directory containing SMT2 files.", + ) + parser.add_argument( + "-o", + "--output_jsonl", + required=True, + help="Path to the output JSONL file.", + ) + parser.add_argument( + "-l", + "--log_file", + default='logs/test_pysmt_monabs.log', + help="Optional log file path.", + ) + parser.add_argument( + "-t", + "--timeout", + type=float, + default=30, + help="Timeout in seconds for each algorithm run.", + ) + parser.add_argument( + "-w", + "--max_workers", + type=int, + default=10, + help="Number of worker processes.", + ) + args = parser.parse_args() + + logger = setup_logger(log_file=args.log_file) + + timeout_ms = int(args.timeout * 1000) + logger.info("Set solver timeout to %d milliseconds", timeout_ms) + + smt2_files = collect_smt2_files(args.input_directory) + logger.info("Found %d SMT2 files under %s", len(smt2_files), args.input_directory) + + os.makedirs(os.path.dirname(args.output_jsonl), exist_ok=True) + with open(args.output_jsonl, "a", encoding="utf-8") as f: + with mp.Pool(processes=args.max_workers) as pool: + for result in pool.imap_unordered( + process_file, ((p, timeout_ms) for p in smt2_files) + ): + logger.info("Processed %s [%s]", result.get("id"), result.get("status")) + f.write(json.dumps(result) + "\n") + f.flush() + + +if __name__ == "__main__": + main() diff --git a/aria/monabs/utils/__init__.py b/aria/monabs/utils/__init__.py index ee6cb854..8b085fe9 100644 --- a/aria/monabs/utils/__init__.py +++ b/aria/monabs/utils/__init__.py @@ -1,5 +1,6 @@ """Utility functions for monadic predicate abstraction.""" -from .formular_generator import * -from .parse_monabs import * +from .logger import * from .parse_monabs_pysmt import * +from .utils import * +from .config import * diff --git a/aria/monabs/utils/config.py b/aria/monabs/utils/config.py new file mode 100644 index 00000000..311e6c86 --- /dev/null +++ b/aria/monabs/utils/config.py @@ -0,0 +1 @@ +MIN_LENGTH = 10 \ No newline at end of file diff --git a/aria/monabs/utils/formular_generator.py b/aria/monabs/utils/formular_generator.py deleted file mode 100644 index ef6ae095..00000000 --- a/aria/monabs/utils/formular_generator.py +++ /dev/null @@ -1,325 +0,0 @@ -""" -Randomly generate a formula using z3's Python APIs - -NOTE: This file is a quite simplified implementation - For generating more diverse and complex queries, - please refer to grammar_gene.py -""" - -import random - -import z3 - - -class FormulaGenerator: # pylint: disable=too-many-instance-attributes - """Generate random formulas using Z3's Python APIs.""" - - def __init__( - self, init_vars, bv_signed=True, bv_no_overflow=False, bv_no_underflow=False - ): - self.bools = [] - self.use_int = False - self.ints = [] - self.use_real = False - self.reals = [] - - self.use_bv = False - self.bvs = [] - # hard_bools are the cnts that must enforced - # e.g., to enforce the absence of overflow and underflow! - self.hard_bools = [] - self.bv_signed = bv_signed - self.bv_no_overflow = bv_no_overflow - self.bv_no_underflow = bv_no_underflow - - for var in init_vars: - if z3.is_int(var): - self.ints.append(var) - elif z3.is_real(var): - self.reals.append(var) - elif z3.is_bv(var): - self.bvs.append(var) - - if len(self.ints) > 0: - self.use_int = True - for _ in range(random.randint(3, 6)): - self.ints.append(FormulaGenerator.random_int()) - - if len(self.reals) > 0: - self.use_real = True - for _ in range(random.randint(3, 6)): - self.reals.append(FormulaGenerator.random_real()) - - if len(self.bvs) > 0: - self.use_bv = True - bvsort = self.bvs[0].sort() - for _ in range(random.randint(3, 6)): - self.bvs.append(z3.BitVecVal(random.randint(1, 100), bvsort.size())) - - @staticmethod - def random_int(): - """Generate a random integer value.""" - return z3.IntVal(random.randint(-100, 100)) - - @staticmethod - def random_real(): - """Generate a random real value.""" - return z3.IntVal(random.randint(-100, 100)) - - def int_from_int(self): - """Generate integer expressions from existing integers.""" - # TODO: also use constant - if len(self.ints) >= 2: - # while True: - # data = random.sample(self.ints, 2) - # i1 = data[0] - # i2 = data[1] - # if not (z3.is_int_value(i1) and z3.is_int_value(i2)): - # break - data = random.sample(self.ints, 2) - i1 = data[0] - i2 = data[1] - # [+, -, *, /, mod] - prob = random.random() - if prob <= 0.2: - self.ints.append(i1 + i2) - elif prob <= 0.4: - self.ints.append(i1 - i2) - elif prob <= 0.6: - self.ints.append(i1 * i2) - elif prob <= 0.8: - self.ints.append(i1 / i2) - else: - # is this OK? - self.ints.append(i1 % i2) - - def real_from_real(self): - """Generate real expressions from existing reals.""" - if len(self.reals) >= 2: - data = random.sample(self.reals, 2) - r1 = data[0] - r2 = data[1] - # [+, -, *, /] - prob = random.random() - if prob <= 0.25: - self.reals.append(r1 + r2) - elif prob <= 0.5: - self.reals.append(r1 - r2) - elif prob <= 0.75: - self.reals.append(r1 * r2) - else: - self.reals.append(r1 / r2) - - def bv_from_bv(self): - """Generate bit-vector expressions from existing bit-vectors.""" - if len(self.bvs) >= 2: - data = random.sample(self.bvs, 2) - r1 = data[0] - r2 = data[1] - # [+, -, *, /] - prob = random.random() - if prob <= 0.25: - self.bvs.append(r1 + r2) - if self.bv_no_overflow: - self.hard_bools.append( - z3.BVAddNoOverflow(r1, r2, signed=self.bv_signed) - ) - if self.bv_no_underflow: - self.hard_bools.append(z3.BVAddNoUnderflow(r1, r2)) - elif prob <= 0.5: - self.bvs.append(r1 - r2) - if self.bv_no_underflow: - self.hard_bools.append(z3.BVSubNoOverflow(r1, r2)) - if self.bv_no_underflow: - self.hard_bools.append( - z3.BVSubNoUnderflow(r1, r2, signed=self.bv_signed) - ) - elif prob <= 0.75: - self.bvs.append(r1 * r2) - if self.bv_no_underflow: - self.hard_bools.append( - z3.BVMulNoOverflow(r1, r2, signed=self.bv_signed) - ) - if self.bv_no_underflow: - self.hard_bools.append(z3.BVMulNoUnderflow(r1, r2)) - else: - self.bvs.append(r1 / r2) - if self.bv_signed: - self.hard_bools.append(z3.BVSDivNoOverflow(r1, r2)) - - def bool_from_int(self): - """Generate boolean expressions from integer comparisons.""" - if len(self.ints) >= 2: - # while True: - # data = random.sample(self.ints, 2) - # i1 = data[0] - # i2 = data[1] - # if not (z3.is_int_value(i1) and z3.is_int_value(i2)): - # break - data = random.sample(self.ints, 2) - i1 = data[0] - i2 = data[1] - # [<, <=, ==, >, >=, !=] - prob = random.random() - if prob <= 0.16: - new_bool = i1 < i2 - elif prob <= 0.32: - new_bool = i1 <= i2 - elif prob <= 0.48: - new_bool = i1 == i2 - elif prob <= 0.62: - new_bool = i1 > i2 - elif prob <= 0.78: - new_bool = i1 >= i2 - else: - new_bool = i1 != i2 - self.bools.append(new_bool) - - def bool_from_real(self): - """Generate boolean expressions from real comparisons.""" - if len(self.reals) >= 2: - data = random.sample(self.reals, 2) - i1 = data[0] - i2 = data[1] - # [<, <=, ==, >, >=, !=] - prob = random.random() - if prob <= 0.16: - new_bool = i1 < i2 - elif prob <= 0.32: - new_bool = i1 <= i2 - elif prob <= 0.48: - new_bool = i1 == i2 - elif prob <= 0.62: - new_bool = i1 > i2 - elif prob <= 0.78: - new_bool = i1 >= i2 - else: - new_bool = i1 != i2 - self.bools.append(new_bool) - - def bool_from_bv(self): - """Generate boolean expressions from bit-vector comparisons.""" - unsigned = not self.bv_signed - if len(self.bvs) >= 2: - data = random.sample(self.bvs, 2) - bv1 = data[0] - bv2 = data[1] - prob = random.random() - # print(bv1.sort(), bv2.sort()) - if prob <= 0.16: - if unsigned: - new_bv = z3.ULT(bv1, bv2) - else: - new_bv = bv1 < bv2 - elif prob <= 0.32: - if unsigned: - new_bv = z3.ULE(bv1, bv2) - else: - new_bv = bv1 <= bv2 - elif prob <= 0.48: - new_bv = bv1 == bv2 - elif prob <= 0.62: - if unsigned: - new_bv = z3.UGT(bv1, bv2) - else: - new_bv = bv1 > bv2 - elif prob <= 0.78: - if unsigned: - new_bv = z3.UGE(bv1, bv2) - else: - new_bv = bv1 >= bv2 - else: - new_bv = bv1 != bv2 - self.bools.append(new_bv) - - def bool_from_bool(self): - """Generate boolean expressions from existing booleans.""" - if len(self.bools) >= 2: - if random.random() < 0.22: - b = random.choice(self.bools) - self.bools.append(z3.Not(b)) - return - - data = random.sample(self.bools, 2) - b1 = data[0] - b2 = data[1] - # [and, or, xor, implies] - prob = random.random() - if prob <= 0.25: - self.bools.append(z3.And(b1, b2)) - elif prob <= 0.5: - self.bools.append(z3.Or(b1, b2)) - elif prob <= 0.75: - self.bools.append(z3.Xor(b1, b2)) - else: - self.bools.append(z3.Implies(b1, b2)) - - def generate_formula(self): - """Generate a random formula.""" - for _ in range(random.randint(3, 8)): - if self.use_int: - self.bool_from_int() - if self.use_real: - self.bool_from_real() - if self.use_bv: - self.bool_from_bv() - - for i in range(8): - if random.random() < 0.33: - if self.use_int: - self.int_from_int() - if self.use_real: - self.real_from_real() - if self.use_bv: - self.bv_from_bv() - - if random.random() < 0.33: - if self.use_int: - self.bool_from_int() - if self.use_real: - self.bool_from_real() - if self.use_bv: - self.bool_from_bv() - - if random.random() < 0.33: - self.bool_from_bool() - - max_assert = random.randint(5, 30) - res = [] - assert len(self.bools) >= 1 - for _ in range(max_assert): - clen = random.randint(1, 8) # clause length - if clen == 1: - cls = random.choice(self.bools) - else: - cls = z3.Or(random.sample(self.bools, min(len(self.bools), clen))) - res.append(cls) - - if len(self.hard_bools) > 1: - res += self.hard_bools - - if len(res) == 1: - return res[0] - return z3.And(res) - - def generate_formula_as_str(self): - """Generate formula and return as SMT-LIB string.""" - mutant = self.generate_formula() - sol = z3.Solver() - sol.add(mutant) - smt2_string = sol.to_smt2() - return smt2_string - - def get_preds(self, k): - """Get k random predicates from generated booleans.""" - res = [] - for _ in range(k): - res.append(random.choice(self.bools)) - return res - - -# if __name__ == "__main__": -# w, x, y, z = z3.Ints("w x y z") -# test = FormulaGenerator([w, x, y, z]) -# print(test.generate_formula()) -# print(test.get_preds(random.randint(50, 150))) diff --git a/aria/monabs/utils/logger.py b/aria/monabs/utils/logger.py new file mode 100644 index 00000000..9290f661 --- /dev/null +++ b/aria/monabs/utils/logger.py @@ -0,0 +1,42 @@ +"""Logging helpers for MPA utilities.""" + +from __future__ import annotations + +import logging +import os +from typing import Optional + + +def setup_logger( + name: str = "mpa", + log_file: Optional[str] = None, + level: int = logging.INFO, +) -> logging.Logger: + """Create or retrieve a configured logger. + + Args: + name: Logger name. + log_file: Optional file path to log to. + level: Logging level. + + Returns: + Configured logger. + """ + logger = logging.getLogger(name) + if logger.handlers: + return logger + + logger.setLevel(level) + formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") + + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(formatter) + logger.addHandler(stream_handler) + + if log_file: + os.makedirs(os.path.dirname(log_file), exist_ok=True) + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + return logger diff --git a/aria/monabs/utils/parse_monabs.py b/aria/monabs/utils/parse_monabs.py deleted file mode 100644 index 23b71d6b..00000000 --- a/aria/monabs/utils/parse_monabs.py +++ /dev/null @@ -1,578 +0,0 @@ -""" -Parse Monadic Predicate Abstraction Queries - -Given a formula and a set of predicates {P1,...,Pn}, -decide for each Pi, whether F and Pi is satisfiable or not. -""" - -import logging -import re -import sys -from functools import reduce -from typing import List, Any - -import z3 -from z3 import ( # noqa: F401 - And, - Array, - ArraySort, - ArraySortRef, - BitVec, - BitVecSort, - BitVecSortRef, - BitVecVal, - Bool, - BoolSort, - BoolVal, - Concat, - Distinct, - Extract, - FP, - FPVal, - FPSort, - FPSortRef, - Function, - If, - Implies, - Int, - IntSort, - LShR, - Not, - Or, - Real, - RealSort, - Select, - SignExt, - Solver, - Sort, - Store, - UDiv, - UGE, - UGT, - ULE, - ULT, - URem, - ZeroExt, - Z3_BOOL_SORT, - Z3_INT_SORT, - Z3_REAL_SORT, - fpAbs, - fpAdd, - fpDiv, - fpEQ, - fpGEQ, - fpGT, - fpIsInf, - fpIsNaN, - fpIsZero, - fpLEQ, - fpLT, - fpMul, - fpNeg, - fpSub, - RNE, - SRem, -) - -# print(sys.getrecursionlimit()) -sys.setrecursionlimit(200000) # Adjust the value accordingly - -logger = logging.getLogger(__name__) - - -class MonAbsSMTLIBParser: - """Parser for Monadic Predicate Abstraction SMT-LIB queries.""" - - def __init__(self, **kwargs): - self.solver = Solver() - self.variables = {} - self.functions = {} # Store declared functions - self.let_bindings = {} # Store let bindings - # Stack of constraints for each scope level - # First list (index 0) contains global constraints - self.constraints_stack: List[List[Any]] = [[]] - - # Stack of varibles for each scope level - # For arrays and UFs, we may also need to record the scope of sorts - self.variables_stack: List[List[str]] = [[]] - - self.check_sat_results = [] # for recording the oracle - - self.logic = kwargs.get("logic", None) - - # "only_parse" mode: do not execute the check-sat commands; just record the cnts - # if we need to record the oracle, we shoud ignore the option. - self.only_parse = kwargs.get("only_parse", False) - - # extract precond: z3.ExprRef, cnt_list: List[z3.ExprRef] - self.flag: bool = False - self.cnt: z3.ExprRef = z3.BoolVal(True) - self.precond: z3.ExprRef = z3.BoolVal(True) - self.cnt_list: List[z3.ExprRef] = [] - - def flatten_sort(self, sort_expr): - """Flatten sort expression appropriately""" - if isinstance(sort_expr, list): - # If it's a single-element list containing another list, unwrap it - if len(sort_expr) == 1 and isinstance(sort_expr[0], list): - return self.flatten_sort(sort_expr[0]) - # Otherwise, keep the list structure but flatten its elements - return [ - self.flatten_sort(x) if isinstance(x, list) else x for x in sort_expr - ] - return sort_expr - - def get_sort(self, sort_expr) -> Sort: - """Parse sort expressions into z3 sorts""" - # Flatten the sort expression first - sort_expr = self.flatten_sort(sort_expr) - - if isinstance(sort_expr, str): - if sort_expr == "Bool": - return BoolSort() - if sort_expr == "Int": - return IntSort() - if sort_expr == "Real": - return RealSort() - raise ValueError(f"Unknown sort: {sort_expr}") - elif isinstance(sort_expr, list): - if sort_expr[0] == "_": - if sort_expr[1] == "BitVec": - return BitVecSort(int(sort_expr[2])) - elif sort_expr[1] == "FP" or sort_expr[1] == "FloatingPoint": - return FPSort(int(sort_expr[2]), int(sort_expr[3])) - elif sort_expr[0] == "Array": - domain = self.get_sort(sort_expr[1]) - range_sort = self.get_sort(sort_expr[2]) - return ArraySort(domain, range_sort) - raise ValueError(f"Invalid sort expression: {sort_expr}") - - def current_scope_level(self) -> int: - """Return the current scope level (0 is global scope)""" - return len(self.constraints_stack) - 1 - - def add_constraint(self, constraint, expr): - """Add a constraint to the current scope""" - self.constraints_stack[-1].append((constraint, expr)) - - def get_current_scope_constraints(self): - """Get constraints in the current scope""" - return self.constraints_stack[-1] - - def get_all_active_constraints(self): - """Get all constraints active in current scope (including parent scopes)""" - all_constraints = [] - for scope in self.constraints_stack: - all_constraints.extend(scope) - return all_constraints - - def tokenize(self, s): - """Tokenize SMT-LIB input string.""" - # Remove comments - s = re.sub(";.*\n", "\n", s) - # Add spaces around parentheses - s = s.replace("(", " ( ").replace(")", " ) ") - # Split into tokens - return [token for token in s.split() if token] - - def parse_tokens(self, tokens): - """Parse tokens into expression tree.""" - if not tokens: - return None - - if tokens[0] == "(": - expression = [] - tokens.pop(0) # Remove opening '(' - while tokens and tokens[0] != ")": - exp = self.parse_tokens(tokens) - if exp is not None: - expression.append(exp) - if tokens: - tokens.pop(0) # Remove closing ')' - return expression - return tokens.pop(0) - - def create_variable(self, name: str, sort) -> Any: - """Create a variable of the specified sort""" - z3_sort = self.get_sort(sort) - - if isinstance(z3_sort, BitVecSortRef): - return BitVec(name, z3_sort.size()) - if isinstance(z3_sort, FPSortRef): - return FP(name, z3_sort) - if isinstance(z3_sort, ArraySortRef): - return Array(name, z3_sort.domain(), z3_sort.range()) - if z3_sort.kind() == Z3_BOOL_SORT: - return Bool(name) - if z3_sort.kind() == Z3_INT_SORT: - return Int(name) - if z3_sort.kind() == Z3_REAL_SORT: - return Real(name) - raise ValueError(f"Unsupported sort kind: {z3_sort.kind()}") - - def process_command(self, command): - """Process a parsed SMT-LIB command.""" - if not isinstance(command, list): - return - - cmd = command[0] - - if cmd == "set-logic": - self.logic = command[1] - - if cmd == "declare-const": - name = command[1] - sort = command[2] if isinstance(command[2], str) else command[2:] - self.variables[name] = self.create_variable(name, sort) - - elif cmd == "declare-fun": - # Handle function declarations - name = command[1] - domain_sorts = [self.get_sort(s) for s in command[2]] - if len(domain_sorts) == 0: - sort = command[3] if isinstance(command[3], str) else command[3:] - self.variables[name] = self.create_variable(name, sort) - else: - range_sort = self.get_sort(command[3]) - self.functions[name] = Function(name, *domain_sorts, range_sort) - - # elif cmd == 'declare-sort': - # TODO: also use constant - - if cmd == "assert": - expr = self.build_expression(command[1]) - self.solver.add(expr) - # Store both the original constraint and the built z3 expression - self.add_constraint(command[1], expr) - if not self.flag: - self.precond = z3.And(self.precond, expr) - if self.flag: - self.cnt = z3.And(self.cnt, expr) - - if cmd == "push": - self.solver.push() - self.flag = True - # Create new scope for constraints - self.constraints_stack.append([]) - - if cmd == "pop": - if len(self.constraints_stack) <= 1: - raise ValueError("Cannot pop global scope") - self.solver.pop() - self.flag = False - self.cnt_list.append(self.cnt) - self.cnt = z3.BoolVal(True) - # Remove constraints from the current scope - popped_constraints = self.constraints_stack.pop() - logger.debug( - "Popped constraints from scope %s:", len(self.constraints_stack) - ) - for original, _ in popped_constraints: - logger.debug(" %s", original) - # print(f"Popped constraints from scope {len(self.constraints_stack)}:") - # for original, _ in popped_constraints: - # print(f" {original}") - - if cmd == "check-sat": - # if we set a "only parse mode", do not actually check-sat - if self.only_parse: - return - print("---------------------------------------------------") - print(self.solver) - result = self.solver.check() - print(f"check-sat result: {result}") - print("Current scope constraints:") - for original, _ in self.get_current_scope_constraints(): - print(f" {original}") - self.check_sat_results.append(result) - - def get_default_fp_sort(self): - """Get default floating point sort.""" - return FPSort(8, 24) # Single precision - - def parse_special_fp_value(self, value): - """Parse special floating point values like +oo, -oo, NaN.""" - # This is a placeholder - implement based on actual SMT-LIB format - raise NotImplementedError("Special FP value parsing not yet implemented") - - def parse_constant(self, value): - """Parse constants based on the current logic.""" - try: - if self.logic and "FP" in self.logic: - # Handle floating point constants - if value.startswith("#b"): - # Binary format - return FPVal(value[2:], self.get_default_fp_sort()) - if value.startswith("("): - # Special values like +oo, -oo, NaN - return self.parse_special_fp_value(value) - if self.logic and "BV" in self.logic: - # Handle bit-vector constants - if value.startswith("#b"): - return BitVecVal(int(value[2:], 2), len(value[2:])) - if value.startswith("#x"): - return BitVecVal(int(value[2:], 16), len(value[2:]) * 4) - - # Try parsing as regular numeric constant - return float(value) if "." in value else int(value) - except ValueError: - return value - - def build_expression(self, expr): - """Build Z3 expression from parsed SMT-LIB expression.""" - if not isinstance(expr, list): - # Handle constants and variables - if expr in self.variables: - return self.variables[expr] - if expr in self.functions: - return self.functions[expr] - if expr in self.let_bindings: - return self.let_bindings[expr] - if expr == "true": - return z3.BoolVal(True) - if expr == "false": - return z3.BoolVal(False) - return self.parse_constant(expr) - - op = expr[0] - - if isinstance(op, list): - flatten_expr = expr[0] - flatten_expr.append(expr[1]) - return self.build_special_operator(flatten_expr) - - if op == "let": - return self.build_let_expression(expr) - - args = [self.build_expression(arg) for arg in expr[1:]] - - # Theory-specific operations - if op in self.functions: - # Function application - return self.functions[op](*args) - if op == "select": - # Array select - return Select(args[0], args[1]) - if op == "store": - # Array store - return Store(args[0], args[1], args[2]) - if op.startswith("fp."): - # Floating point operations - return self.build_fp_expression(op, args) - if op.startswith("bv"): - # Bit-vector operations - return self.build_bitvector_expression(op, args) - if op == "_": - # Bit-vector constants - return self.build_special_operator(expr) - # Standard operations - return self.build_standard_expression(op, args) - - def build_special_operator(self, expr): - """Build special operators like bit-vector constants and extensions.""" - if expr[1].startswith("bv"): - # Bit-vector constant - value = int(expr[1][2:]) # Extract the value after 'bv' - width = int(expr[2]) # Extract the bit-width - return BitVecVal(value, width) - if expr[1] == "sign_extend": - # Sign extension - extension_bits = int(expr[2]) - value = self.build_expression(expr[3]) - return SignExt(extension_bits, value) - if expr[1] == "zero_extend": - # Zero extension - extension_bits = int(expr[2]) - value = self.build_expression(expr[3]) - return ZeroExt(extension_bits, value) - if expr[1] == "extract": - # Bit extraction - high = int(expr[2]) - low = int(expr[3]) - value = self.build_expression(expr[4]) - return Extract(high, low, value) - raise ValueError(f"Unknown special operator: {expr[1]}") - - def build_let_expression(self, expr): - """Build let expression with local bindings.""" - bindings = expr[1] - body = expr[2] - - # Create a mapping of variable names to their expressions - for binding in bindings: - var_name = binding[0] - var_expr = self.build_expression(binding[1]) - self.let_bindings[var_name] = var_expr - - return self.build_expression(body) - - def build_standard_expression(self, op, args): - """Build expression for standard operations.""" - if op == "+": - return sum(args) - if op == "-": - return args[0] - args[1] if len(args) == 2 else -args[0] - if op == "*": - return reduce(lambda x, y: x * y, args) - if op == "/": - return args[0] / args[1] - if op == "div": - return args[0] / args[1] - if op == ">": - return args[0] > args[1] - if op == "<": - return args[0] < args[1] - if op == ">=": - return args[0] >= args[1] - if op == "<=": - return args[0] <= args[1] - if op == "=": - return args[0] == args[1] - if op == "mod": - return args[0] % args[1] - if op == "distinct": - return Distinct(*args) - if op == "concat": - return Concat(*args) - if op == "and": - return And(*args) - if op == "or": - return Or(*args) - if op == "not": - return Not(args[0]) - if op == "=>": - return Implies(args[0], args[1]) - if op == "ite": - return If(args[0], args[1], args[2]) - raise ValueError(f"Unknown operator: {op}") - - def build_fp_expression(self, op, args): - """Build floating-point expressions.""" - rm = RNE() # Default rounding mode - if op == "fp.add": - return fpAdd(rm, args[0], args[1]) - if op == "fp.sub": - return fpSub(rm, args[0], args[1]) - if op == "fp.mul": - return fpMul(rm, args[0], args[1]) - if op == "fp.div": - return fpDiv(rm, args[0], args[1]) - if op == "fp.neg": - return fpNeg(args[0]) - if op == "fp.abs": - return fpAbs(args[0]) - if op == "fp.lt": - return fpLT(args[0], args[1]) - if op == "fp.gt": - return fpGT(args[0], args[1]) - if op == "fp.leq": - return fpLEQ(args[0], args[1]) - if op == "fp.geq": - return fpGEQ(args[0], args[1]) - if op == "fp.eq": - return fpEQ(args[0], args[1]) - if op == "fp.isNaN": - return fpIsNaN(args[0]) - if op == "fp.isInfinite": - return fpIsInf(args[0]) - if op == "fp.isZero": - return fpIsZero(args[0]) - raise ValueError(f"Unknown FP operator: {op}") - - def build_bitvector_expression(self, op, args): - """Build bit-vector expression with comprehensive operation support.""" - # Comparison operations - if op == "bvult": - return ULT(args[0], args[1]) - if op == "bvule": - return ULE(args[0], args[1]) - if op == "bvugt": - return UGT(args[0], args[1]) - if op == "bvuge": - return UGE(args[0], args[1]) - if op == "bvslt": - return args[0] < args[1] - if op == "bvsle": - return args[0] <= args[1] - if op == "bvsgt": - return args[0] > args[1] - if op == "bvsge": - return args[0] >= args[1] - - # Arithmetic operations - if op == "bvneg": - return -args[0] - if op == "bvadd": - return args[0] + args[1] - if op == "bvsub": - return args[0] - args[1] - if op == "bvmul": - return args[0] * args[1] - if op == "bvudiv": - return UDiv(args[0], args[1]) - if op == "bvsdiv": - return args[0] / args[1] - if op == "bvurem": - return URem(args[0], args[1]) - if op == "bvsrem": - return SRem(args[0], args[1]) - if op == "bvsmod": - # return SMod(args[0], args[1]) - return args[0] % args[1] - - # Bitwise operations - if op == "bvand": - return args[0] & args[1] - if op == "bvor": - return args[0] | args[1] - if op == "bvxor": - return args[0] ^ args[1] - if op == "bvnot": - return ~args[0] - if op == "bvnand": - return ~(args[0] & args[1]) - if op == "bvnor": - return ~(args[0] | args[1]) - if op == "bvxnor": - return ~(args[0] ^ args[1]) - - # Shift operations - if op == "bvshl": - return args[0] << args[1] - if op == "bvlshr": - return LShR(args[0], args[1]) - if op == "bvashr": - return args[0] >> args[1] - - raise ValueError(f"Unknown bit-vector operator: {op}") - - def parse_string(self, content): - """Parse SMT-LIB content from string.""" - tokens = self.tokenize(content) - while tokens: - command = self.parse_tokens(tokens) - if command: - self.process_command(command) - - def parse_file(self, filename): - """Parse SMT-LIB content from file.""" - with open(filename, "r", encoding="utf-8") as file: - content = file.read() - self.parse_string(content) - - def extract_scope_constraints(self): - """Extract and print scope constraints.""" - print("precond:\n->", self.precond) - print("cnt_list:") - for cnt in self.cnt_list: - print("->", cnt) - - return self.precond, self.cnt_list - - def get_precond(self): - """Get precondition.""" - return self.precond - - def get_cnt_list(self): - """Get constraint list.""" - return self.cnt_list diff --git a/aria/monabs/utils/parse_monabs_pysmt.py b/aria/monabs/utils/parse_monabs_pysmt.py index 48964886..9c84fee0 100644 --- a/aria/monabs/utils/parse_monabs_pysmt.py +++ b/aria/monabs/utils/parse_monabs_pysmt.py @@ -68,3 +68,9 @@ def parse_monabs_pysmt(filename: str) -> Tuple: """Convenience wrapper returning (precond, cnt_list).""" parser = MonAbsPySMTParser() return parser.parse_file(filename) + +# if __name__ == "__main__": +# filepath = "/home/xjn/MPA/MPA/data/bash/cons_6043.smt2" +# precond, constraints = parse_monabs_pysmt(filepath) +# print("Precondition:", precond) +# print("Number of constraints:", constraints) diff --git a/aria/monabs/utils/utils.py b/aria/monabs/utils/utils.py new file mode 100644 index 00000000..00265887 --- /dev/null +++ b/aria/monabs/utils/utils.py @@ -0,0 +1,10 @@ +import os +from typing import Dict, List + +def collect_smt2_files(root_dir: str) -> List[str]: + smt2_files: List[str] = [] + for root, _, files in os.walk(root_dir): + for name in files: + if name.endswith(".smt2"): + smt2_files.append(os.path.join(root, name)) + return smt2_files \ No newline at end of file