Skip to content
182 changes: 160 additions & 22 deletions src/dbtest/src/mda_detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
# */


import Queue
import queue
import os
import time
import sys

MAX_TS = 99999999999999999999

class Edge:
def __init__(self, type, out):
Expand All @@ -32,7 +34,9 @@ def __init__(self, op_type, txn_num, op_time, value):
class Txn:
    """Timestamp bookkeeping for one transaction.

    Attributes are plain instance fields that other code in this file
    reads and overwrites directly.
    """

    def __init__(self):
        # Placeholder begin timestamp (presumably -1 means "not begun yet";
        # confirm against the code that assigns begin_ts).
        self.begin_ts = -1
        # MAX_TS is the initial value for both timestamps below. The former
        # hard-coded literal duplicated MAX_TS and was overwritten at once,
        # so only the named constant is kept.
        self.end_ts = MAX_TS
        # Remains MAX_TS unless snapshot reads are enabled and a snapshot
        # timestamp is recorded for this transaction.
        self.snapshot_ts = MAX_TS


"""
Find the total variable number.
Expand Down Expand Up @@ -115,23 +119,39 @@ def set_finish_time(op_time, data_op_list, query, txn, version_list):
data_value += tmp
tmp, tmp1 = "", ""
data_value = int(data_value)
for t in txn:
for txn_num,t in enumerate(txn):
if t.begin_ts == op_time:
t.begin_ts = data_value
if snapshot_read and db_type in {"pg"}:
txn[txn_num].snapshot_ts = data_value
if t.end_ts == op_time:
t.end_ts = data_value
# update 'visible_ts' for versions installed by the committed transaction
if isolation_level != "ru" and isolation_level!="unset":
for i, data_versions in enumerate(version_list):
for j, version in enumerate(data_versions):
if version[1] == txn_num:
version_list[i][j][2] = data_value
for i, list1 in enumerate(data_op_list):
for op in list1:
if op.op_time == op_time:
op.op_time = data_value
if op.op_type == "W":
version_list[i].append(op.value)
if isolation_level == "ru" or isolation_level == "unset":
visible_ts = op.op_time
else:
visible_ts = MAX_TS
#update 'snapshot_ts'
if op.op_type == "R" or op.op_type == "P":
if snapshot_read and db_type in {"mariadb"}:
txn[op.txn_num].snapshot_ts = min(op.op_time,txn[op.txn_num].snapshot_ts)
elif op.op_type == "W":
version_list[i].append([op.value,op.txn_num,visible_ts])
op.value = len(version_list[i]) - 1
elif op.op_type == "D":
version_list[i].append(-1)
version_list[i].append([-1,op.txn_num,visible_ts])
op.value = len(version_list[i]) - 1
elif op.op_type == "I":
version_list[i].append(op.value)
version_list[i].append([op.value,op.txn_num,visible_ts])
op.value = len(version_list[i]) - 1


Expand Down Expand Up @@ -234,7 +254,7 @@ def build_graph(data_op_list, indegree, edge, txn):
None
"""
def insert_edge(data1, data2, indegree, edge, txn):
if check_concurrency(data1, data2, txn):
# if check_concurrency(data1, data2, txn):
edge_type, data1, data2 = get_edge_type(data1, data2, txn)
if edge_type != "RR" and edge_type != "RCR" and data1.txn_num != data2.txn_num:
indegree[data2.txn_num] += 1
Expand All @@ -257,7 +277,7 @@ def insert_edge(data1, data2, indegree, edge, txn):
def init_record(query, version_list):
    # Register the initial version of a data item described by `query`.
    #
    # query:        statement text; the key is extracted with
    #               find_data(query, "(") and the value with
    #               find_data(query, ",").
    # version_list: per-key list of version records; the new record is
    #               appended to version_list[key] in place.
    # Returns None.
    key = find_data(query, "(")
    value = find_data(query, ",")
    # Every version is stored as (value, installing txn, visible_ts), the
    # same layout the other writers in this file use. Appending the bare
    # value as well would leave a non-indexable entry in the list and shift
    # the positions of all later versions, so only the record is appended.
    # -1 is not a transaction index and 0 is the visible_ts (presumably
    # "visible from the very beginning"; confirm against the readers).
    version_list[key].append((value, -1, 0))


"""
Expand Down Expand Up @@ -285,14 +305,25 @@ def readVersion_record(query, op_time, data_op_list, version_list):
for op in list1:
if op.op_time == op_time:
value = op.value
if len(version_list[value]) == 0:
snapshot_ts = txn[op.txn_num].snapshot_ts
versions = version_list[value]
if len(versions) == 0:
op.value = -1
else:
if -1 not in version_list[value]:
error_message = "Value exists, but did not successully read"
return error_message
pos = version_list[value].index(-1)
op.value = pos
else:
deleted = False
for i, version in enumerate(versions):
version_val = version[0]
install_txn = version[1]
visible_ts = version[2]
if (visible_ts < snapshot_ts and visible_ts < MAX_TS) or install_txn == op.txn_num :
if version_val== -1 :
deleted = True
op.value = i
else:
error_message = "Value exists, but did not successully read"
return error_message
if not deleted:
op.value = -1
else:
for s in data:
key = find_data(s, "(")
Expand All @@ -301,14 +332,30 @@ def readVersion_record(query, op_time, data_op_list, version_list):
for op in list1:
if key == i and op.op_time == op_time:
value1 = op.value
if len(version_list[value1]) == 0:
versions = version_list[value1]
snapshot_ts = txn[op.txn_num].snapshot_ts
if len(versions) == 0:
op.value = -1
else:
if version_list[value1].count(value) == 0:
find = False
latest = 1
version_size = len(versions)
for i, version in enumerate(reversed(versions)):
version_val = version[0]
install_txn = version[1]
visible_ts = version[2]
if (visible_ts < snapshot_ts and visible_ts < MAX_TS) or install_txn == op.txn_num:
latest -= 1
if version_val== value:
find = True
op.value = version_size-i-1
break
if not find:
error_message = "Read version that does not exist"
return error_message
pos = version_list[value1].index(value)
op.value = pos
return error_message
if latest < 0 and db_type!= "unset":
error_message = "Read version that is not the latest version"
return error_message

return error_message
# for i, list1 in enumerate(data_op_list):
Expand Down Expand Up @@ -563,7 +610,7 @@ def remove_unfinished_operation(data_op_list):
"""
# toposort to determine whether there is a cycle
def check_cycle(edge, indegree, total):
q = Queue.Queue()
q = queue.Queue()
for i, degree in enumerate(indegree):
if degree == 0: q.put(i)
ans = []
Expand Down Expand Up @@ -616,6 +663,11 @@ def dfs(result_folder, ts_now, now, type):
for i in range(0, len(path)):
f.write(str(path[i]))
if i != len(path) - 1: f.write("->" + edge_type[i+1] + "->")
if db_type != "unset":
if verify_cycle(edge_type,isolation_level):
f.write(" : Accept")
else:
f.write(" : Reject")
f.write("\n\n")
path.pop()
edge_type.pop()
Expand Down Expand Up @@ -693,9 +745,95 @@ def print_error(result_folder, ts_now, error_message):
f.write("\n\n")


"""
Decide whether a dependency cycle is acceptable under an isolation level.

Every edge type is first reduced to a two-letter form built from its first
and last character, where the write-like operations I, D and W all collapse
to "W". Entries equal to 'null' are skipped.

Args:
    edge_type (list): Edge-type strings collected along the path.
    isolation_level (str): 'ru', 'rc', 'rr' or 'ser'. Any other value
        accepts the cycle.

Returns:
    bool: True if the cycle is accepted, False if it is rejected.

Note:
    The 'rr' branch reads the module-level `db_type`.
"""
def verify_cycle(edge_type, isolation_level):
    write_ops = ("I", "D", "W")

    def collapse(op):
        # Insert / Delete / Write are all treated as a plain write.
        return "W" if op in write_ops else op

    edges = [collapse(t[0]) + collapse(t[-1]) for t in edge_type if t != 'null']

    def made_only_of(*kinds):
        # True when every generalized edge is one of `kinds`
        # (vacuously true for an empty edge list).
        return all(e in kinds for e in edges)

    if isolation_level == "ser":
        # Serializable: no cycle is ever acceptable.
        return False

    if isolation_level == "ru":
        # Read uncommitted: a cycle made purely of write dependencies (G0) is rejected.
        return not made_only_of("WW")

    if isolation_level == "rc":
        # Read committed: a cycle made purely of write/read dependencies (G1c) is rejected.
        return not made_only_of("WW", "WR", "WP")

    if isolation_level == "rr":
        if db_type == "pg":
            # Snapshot isolation: accept only if two anti-dependency edges
            # sit next to each other in the list.
            # NOTE(review): only linearly adjacent entries are compared; the
            # pair formed by the last and the first edge is never checked.
            # Confirm whether callers expect cyclic adjacency here.
            anti = ("RW", "PW")
            return any(
                first in anti and second in anti
                for first, second in zip(edges, edges[1:])
            )
        if db_type == "mariadb":
            # MariaDB's repeatable read is below Adya's PL-2.99, so the
            # weaker write/read-dependency rule is applied instead.
            return not made_only_of("WW", "WR", "WP")
        # No snapshot: cycles made purely of write, read and anti
        # dependencies (G1c + G2-item) are rejected.
        return not made_only_of("WW", "WR", "WP", "RW", "PW")

    # Unknown / unset isolation level: accept.
    return True


# Short isolation-level codes accepted on the command line, mapped to the
# directory names used for result folders.
isolations = {
    "ru": "read-uncommitted",
    "rc": "read-committed",
    "rr": "repeatable-read",
    "ser": "serializable",
}


def resolve_run_config(argv):
    """Derive (db_type, isolation_level, run_result_folder) from argv.

    Args:
        argv (list): Full argument vector, program name first.

    Returns:
        tuple: ("unset", "unset", "pg/serializable") when no arguments are
        given or the database type is not one of mariadb / pg; otherwise
        (db_type, isolation_level, "<db_type>/<isolation directory>").

    Raises:
        SystemExit: A supported database type was given, but the isolation
            level is missing or is not a key of `isolations`. Previously
            these cases surfaced as a bare IndexError / KeyError.
    """
    args = argv[1:]
    if not args or args[0] not in {"mariadb", "pg"}:
        return "unset", "unset", "pg/serializable"
    if len(args) < 2 or args[1] not in isolations:
        raise SystemExit(
            "usage: " + str(argv[0]) + " [mariadb|pg] [" + "|".join(isolations) + "]"
        )
    return args[0], args[1], f"{args[0]}/{isolations[args[1]]}"


def resolve_snapshot_mode(db_type, isolation_level):
    """Return (effective isolation_level, snapshot_read) for a database.

    For pg, 'ru' is treated as 'rc', and snapshot reads are enabled for
    'rr' and 'ser'. For mariadb, snapshot reads are enabled for 'rr' only.
    In every other case the level is returned unchanged with snapshot reads
    disabled.
    """
    if db_type == "pg":
        if isolation_level == "ru":
            isolation_level = "rc"
        return isolation_level, isolation_level in {"rr", "ser"}
    return isolation_level, db_type == "mariadb" and isolation_level == "rr"


# db_type is one of [mariadb, pg] and isolation_level one of [ru, rc, rr, ser];
# both are "unset" when nothing usable was passed on the command line.
db_type, isolation_level, run_result_folder = resolve_run_config(sys.argv)

result_folder = "check_result/" + run_result_folder
do_test_list = "do_test_list.txt"


# If snapshot_read is False, snapshot_ts is always MAX_TS. This is equivalent to disabling snapshot reads.
# Once a version is visible, it will be read.
# The result folder above is computed before the pg 'ru' -> 'rc' adjustment,
# so it still names the level that was requested.
isolation_level, snapshot_read = resolve_snapshot_mode(db_type, isolation_level)

#ts_now = "_2param_3txn_insert"
ts_now = time.strftime("%Y%m%d_%H%M%S", time.localtime())
if not os.path.exists(result_folder):
Expand Down