Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 137 additions & 120 deletions conflict_resolving.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@
import os
from functools import cmp_to_key
from typing import List, Tuple
import random

import streamlit
import streamlit as st
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
import json

from models import Assertion
from structure import Relationship
import yaml
from pathlib import Path
with open(Path("prompts_in_conflict_resolving.yaml"), "r", encoding="utf-8") as f:
prompts_yaml = yaml.safe_load(f)


def build_prompt_from_yaml(name: str):
raw_messages = prompts_yaml["prompts"][name]["messages"]
return ChatPromptTemplate.from_messages(
[(m["role"], m["content"]) for m in raw_messages]
)

class GlobalGraph:
def __init__(self, relationships: List[Relationship]) -> None:
def __init__(self, relationships: List[Relationship], assertions: List[Assertion]) -> None:
"""
Initialize a graph-based data structure from a list of relationships.

Expand All @@ -29,6 +47,10 @@ def __init__(self, relationships: List[Relationship]) -> None:
- ordered_graph (List[str]): Assertions in traversal or topological order and final desired output.
- assertions_by_layers (dict[int, List[str]]): Assertions grouped by layers.
"""
self.assertion_table: dict[str, Assertion] = {}
for assertion in assertions:
self.assertion_table[assertion.id] = assertion

self.relationship_for_pair: dict[Tuple[str, str], Relationship] = {}
self.relationships: List[Relationship] = relationships

Expand Down Expand Up @@ -354,7 +376,6 @@ def resolve_cycles_and_conflicts(self, automatic) -> bool:
list_of_scc = self.find_nodes_in_scc()
nodes_part_of_scc = [node for sublist in list_of_scc for node in sublist]
solved = True

if len(nodes_part_of_scc) > 0 or len(self.bad_graph) > 0:
solved = False
min_node = self.pick_worst_node(nodes_part_of_scc)
Expand Down Expand Up @@ -430,6 +451,56 @@ def check_used_all_parents(self, node) -> bool:
"""
return self.number_of_visited_parents[node] == len(self.good_graph_2[node])

def get_llm_answer_with_parent(self, rel1: Relationship, rel2: Relationship):
prompt = build_prompt_from_yaml("compare_two_relations_with_parent_bool")

# --- Call the LLM ---
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, openai_api_key=os.getenv("OPENAI_API_KEY"))
chain = prompt | llm # assuming `llm` is your model
response = chain.invoke({
"parent_id": "P1",
"relation_type": "cause",
"r1": {
"assertion1_id": rel1.assertion1_id,
"assertion2_id": rel1.assertion2_id,
"confidence": 0.8,
"explanation": rel1.explanation
},
"r2": {
"assertion1_id": rel2.assertion1_id,
"assertion2_id": rel2.assertion2_id,
"confidence": 0.75,
"explanation": rel2.explanation
}
})

# --- Extract boolean ---
content = response.content.strip()
if content.startswith("```"):
content = content.split("```")[1].split("```")[0].strip()

return json.loads(content)

def get_llm_answer_without_parent(self, node1: Assertion, node2: Assertion):
prompt = build_prompt_from_yaml("compare_two_intro_assertions_bool")

# --- Call the LLM ---
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, openai_api_key=os.getenv("OPENAI_API_KEY"))
chain = prompt | llm
response = chain.invoke({
"first_id": node1.id,
"first_text": node1.content,
"second_id": node2.id,
"second_text": node2.content
})

# --- Extract boolean ---
content = response.content.strip()
if content.startswith("```"):
content = content.split("```")[1].split("```")[0].strip()

return json.loads(content)

def compare_two_relations(self, relation1: Relationship, relation2: Relationship) -> bool:
"""
Compare two Relationship objects to determine which has higher priority.
Expand All @@ -452,14 +523,14 @@ def compare_two_relations(self, relation1: Relationship, relation2: Relationship
- If the weights are equal, the comparison is deferred to LLM.
- Designed to assist in ordering or selecting relationships when constructing or analyzing the graph.
"""
relation_type1 = relation1.relationship_type
relation_type2 = relation2.relationship_type
relation_type1 = getattr(relation1, "relationship_type", None)
relation_type2 = getattr(relation2, "relationship_type", None)
weights = {"cause": 0, "condition": 1, "evidence": 2, "contrast": 3, "background": 4}
if weights[relation_type1] != weights[relation_type2]:
return weights[relation_type1] < weights[relation_type2]
else:
#TODO ask which one is better from Amirali the LLM
return True
answer = self.get_llm_answer_with_parent(relation1, relation2)
return answer

def sort_all_children(self, parent_id: str, children : List[str]) -> List[str]:
"""
Expand All @@ -483,21 +554,14 @@ def sort_all_children(self, parent_id: str, children : List[str]) -> List[str]:
- If `parent_id` is empty, the comparison is deferred to LLM.
- Helps establish an ordered structure among children for DAG traversal, visualization, or analysis.
"""
def comparison(node1: str, node2: str) -> int:
def comparison(node1: str, node2: str) -> bool:
if parent_id == "":
# TODO ask which one is better from Amirali the LLM
return 0
answer = self.get_llm_answer_without_parent(self.assertion_table[node1], self.assertion_table[node2])
return answer
relation1 = self.relationship_for_pair.get((parent_id, node1))
relation2 = self.relationship_for_pair.get((parent_id, node2))

if relation1 is None and relation2 is None:
return 0
if relation1 is None:
return -1 # or 1, depending on how you want to order
if relation2 is None:
return 1

# assume compare_two_relations returns -1, 0, or 1
if relation1 is None or relation2 is None:
return True
return self.compare_two_relations(relation1, relation2)

return sorted(children, key=cmp_to_key(comparison))
Expand Down Expand Up @@ -552,107 +616,34 @@ def ordering_assertions(self, lst: List[str], deep: int):
- Useful for topological sorting, visualization, or downstream graph analysis.
"""
for node in lst:
"""
Produce a fully ordered graph of assertions using topological sorting and layered traversal.

Args:
None

Returns:
None

Description:
- Finds all starting nodes (nodes with no incoming edges) using `find_all_starting_nodes`.
- Sorts the starting nodes according to relationship priority via `sort_all_children` (parent_id is empty for starting nodes).
- Calls `ordering_assertions` recursively to process the entire DAG:
* Updates `ordered_graph` with the traversal order.
* Updates `assertions_by_layers` with nodes grouped by layer depth.
- Returns the final flattened `ordered_graph`.
- Ensures that all dependencies (parents) are respected, producing a valid DAG ordering.
"""
self.ordered_graph.append(node)
self.assertions_by_layers.setdefault(deep, []).append(node)
children = self.get_valid_children(node)
sorted_children = self.sort_all_children(node, children)
self.ordering_assertions(sorted_children, deep + 1)

def order_the_graph(self) -> None:
"""
Produce a fully ordered graph of assertions using topological sorting and layered traversal.

Args:
None

Returns:
None

Description:
- Finds all starting nodes (nodes with no incoming edges) using `find_all_starting_nodes`.
- Sorts the starting nodes according to relationship priority via `sort_all_children` (parent_id is empty for starting nodes).
- Calls `ordering_assertions` recursively to process the entire DAG:
* Updates `ordered_graph` with the traversal order.
* Updates `assertions_by_layers` with nodes grouped by layer depth.
- Returns the final flattened `ordered_graph`.
- Ensures that all dependencies (parents) are respected, producing a valid DAG ordering.
"""
starting_nodes = self.sort_all_children("", self.find_all_starting_nodes())
self.ordering_assertions(starting_nodes, 0)

test_1 = [
# Cycle 1: A1 ↔ A2 ↔ A3 ↔ A1 (all contradictions)
Relationship(assertion1_id="A1", assertion2_id="A2",
relationship_type="contradiction", confidence=0.91,
explanation="A1 claims X holds; A2 asserts not-X."),
Relationship(assertion1_id="A2", assertion2_id="A3",
relationship_type="contradiction", confidence=0.88,
explanation="A2 negates the outcome proposed by A3."),
Relationship(assertion1_id="A3", assertion2_id="A1",
relationship_type="contradiction", confidence=0.90,
explanation="A3 rejects A1’s premise, closing the cycle."),

# Cycle 2: A4 → A5 → A6 → A4 (mixed types)
Relationship(assertion1_id="A4", assertion2_id="A5",
relationship_type="evidence", confidence=0.84,
explanation="A4 cites data that supports A5."),
Relationship(assertion1_id="A5", assertion2_id="A6",
relationship_type="cause", confidence=0.79,
explanation="A5 describes a mechanism that produces A6."),
Relationship(assertion1_id="A6", assertion2_id="A4",
relationship_type="background", confidence=0.73,
explanation="A6 provides context assumed by A4."),

# Cycle 3: A7 → A8 → A9 → A10 → A7 (with contradictions)
Relationship(assertion1_id="A7", assertion2_id="A8",
relationship_type="contradiction", confidence=0.86,
explanation="A7 states a limit exists; A8 says no such limit."),
Relationship(assertion1_id="A8", assertion2_id="A9",
relationship_type="evidence", confidence=0.76,
explanation="A8 references measurements backing A9."),
Relationship(assertion1_id="A9", assertion2_id="A10",
relationship_type="cause", confidence=0.80,
explanation="A9 implies A10 via a causal link."),
Relationship(assertion1_id="A10", assertion2_id="A7",
relationship_type="contradiction", confidence=0.82,
explanation="A10’s conclusion conflicts with A7."),

# Additional contradiction cycles to reach 20 nodes
Relationship(assertion1_id="A11", assertion2_id="A12",
relationship_type="contradiction", confidence=0.85,
explanation="A11 and A12 claim opposing outcomes."),
Relationship(assertion1_id="A12", assertion2_id="A13",
relationship_type="contradiction", confidence=0.83,
explanation="A12 and A13 disagree fundamentally."),
Relationship(assertion1_id="A13", assertion2_id="A11",
relationship_type="contradiction", confidence=0.84,
explanation="A13 refutes A11, completing a contradiction cycle."),

Relationship(assertion1_id="A14", assertion2_id="A15",
relationship_type="contradiction", confidence=0.80,
explanation="A14 predicts growth; A15 predicts decline."),
Relationship(assertion1_id="A15", assertion2_id="A16",
relationship_type="contradiction", confidence=0.82,
explanation="A15 and A16 provide incompatible claims."),
Relationship(assertion1_id="A16", assertion2_id="A14",
relationship_type="contradiction", confidence=0.81,
explanation="A16 undermines A14, closing the cycle."),

# Larger contradiction cycle: A17 → A18 → A19 → A20 → A17
Relationship(assertion1_id="A17", assertion2_id="A18",
relationship_type="contradiction", confidence=0.87,
explanation="A17 and A18 cannot both be true."),
Relationship(assertion1_id="A18", assertion2_id="A19",
relationship_type="contradiction", confidence=0.89,
explanation="A18 asserts the opposite of A19."),
Relationship(assertion1_id="A19", assertion2_id="A20",
relationship_type="contradiction", confidence=0.88,
explanation="A19 and A20 contradict each other."),
Relationship(assertion1_id="A20", assertion2_id="A17",
relationship_type="contradiction", confidence=0.90,
explanation="A20 invalidates A17, completing the cycle.")
]

test_2 = [
Relationship(assertion1_id="A1", assertion2_id="A3", relationship_type="cause", confidence=0.91, explanation="A1 causes A3."),
Relationship(assertion1_id="A1", assertion2_id="A5", relationship_type="evidence", confidence=0.82, explanation="A1 supports A5."),
Expand Down Expand Up @@ -682,16 +673,42 @@ def order_the_graph(self) -> None:
]


assertions_test_2 = [
Assertion(id="A1", content="Coffee helps me wake up in the morning.", confidence=0.86, source="synthetic everyday text"),
Assertion(id="A2", content="Drinking coffee late keeps me from sleeping.", confidence=0.80, source="synthetic everyday text"),
Assertion(id="A3", content="If I skip coffee, I feel tired during the day.", confidence=0.90, source="synthetic everyday text"),
Assertion(id="A4", content="I only drink coffee if there is milk at home.", confidence=0.82, source="synthetic everyday text"),
Assertion(id="A5", content="My friends say coffee improves concentration.", confidence=0.84, source="synthetic everyday text"),
Assertion(id="A6", content="Coffee makes me anxious sometimes.", confidence=0.78, source="synthetic everyday text"),
Assertion(id="A7", content="I started drinking coffee in college.", confidence=0.72, source="synthetic everyday text"),
Assertion(id="A8", content="Having breakfast makes coffee taste better.", confidence=0.88, source="synthetic everyday text"),
Assertion(id="A9", content="Coffee helps me study only if I sleep enough.", confidence=0.80, source="synthetic everyday text"),
Assertion(id="A10", content="I usually drink coffee while working.", confidence=0.81, source="synthetic everyday text"),
Assertion(id="A11", content="Some people say tea is healthier than coffee.", confidence=0.75, source="synthetic everyday text"),
Assertion(id="A12", content="Coffee keeps me focused during long meetings.", confidence=0.83, source="synthetic everyday text"),
Assertion(id="A13", content="Coffee gives me energy before workouts.", confidence=0.85, source="synthetic everyday text"),
Assertion(id="A14", content="I drink coffee only when I feel sleepy.", confidence=0.79, source="synthetic everyday text"),
Assertion(id="A15", content="Studies show coffee lowers some health risks.", confidence=0.84, source="synthetic everyday text"),
Assertion(id="A16", content="Too much coffee can cause headaches.", confidence=0.74, source="synthetic everyday text"),
Assertion(id="A17", content="My family always drinks coffee after dinner.", confidence=0.70, source="synthetic everyday text"),
Assertion(id="A18", content="A warm cup of coffee helps me relax.", confidence=0.87, source="synthetic everyday text"),
Assertion(id="A19", content="Coffee helps me focus only if the room is quiet.", confidence=0.80, source="synthetic everyday text"),
Assertion(id="A20", content="Coffee tastes better when shared with friends.", confidence=0.83, source="synthetic everyday text"),
]


if __name__ == "__main__":
our_graph = GlobalGraph(test_2)
our_graph.resolve_cycles_and_conflicts()
# print(our_graph.ordered_graph)
# print(our_graph.contradiction_graph)
# print(our_graph.contradiction_relations)
# print(our_graph.relationships)
# print(our_graph.contradiction_nodes)
# print(our_graph.nodes)
# print(our_graph.reverse_graph)
# print(our_graph.number_of_visited_parents)
# print(our_graph.relationship_for_pair)
our_graph = GlobalGraph(test_2, assertions_test_2)
our_graph.resolve_cycles_and_conflicts(True)

for node in our_graph.nodes:
print("For node: ", node, " we have: ")
for neigh in our_graph.good_graph_1[node]:
print(neigh)

our_graph.order_the_graph()
print("END:")
print(our_graph.ordered_graph)

for node in our_graph.ordered_graph:
print(our_graph.assertion_table[node].content)
Loading