Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyrit/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from pyrit.memory.azure_sql_memory import AzureSQLMemory
from pyrit.memory.central_memory import CentralMemory
from pyrit.memory.identifier_filters import IdentifierFilter
from pyrit.memory.memory_embedding import MemoryEmbedding
from pyrit.memory.memory_exporter import MemoryExporter
from pyrit.memory.memory_interface import MemoryInterface
Expand All @@ -26,4 +27,5 @@
"MemoryExporter",
"PromptMemoryEntry",
"SeedEntry",
"IdentifierFilter",
]
250 changes: 102 additions & 148 deletions pyrit/memory/azure_sql_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,22 +250,6 @@ def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str
condition = text(conditions).bindparams(**{key: str(value) for key, value in memory_labels.items()})
return [condition]

def _get_message_pieces_attack_conditions(self, *, attack_id: str) -> Any:
    """
    Build the Azure SQL predicate that selects message pieces for one attack.

    The ``attack_identifier`` column is first checked with ISJSON(), then its
    ``$.hash`` property is compared against the supplied ID via JSON_VALUE().

    Args:
        attack_id (str): The attack identifier to filter by; it is converted
            with ``str()`` before being bound.

    Returns:
        Any: SQLAlchemy text condition with the ``json_id`` parameter bound.
    """
    clause = text("ISJSON(attack_identifier) = 1 AND JSON_VALUE(attack_identifier, '$.hash') = :json_id")
    return clause.bindparams(json_id=str(attack_id))

def _get_metadata_conditions(self, *, prompt_metadata: dict[str, Union[str, int]]) -> list[TextClause]:
"""
Generate SQL conditions for filtering by prompt metadata.
Expand Down Expand Up @@ -321,6 +305,93 @@ def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]])
"""
return self._get_metadata_conditions(prompt_metadata=metadata)[0]

def _get_condition_json_property_match(
    self,
    *,
    json_column: Any,
    property_path: str,
    value_to_match: str,
    partial_match: bool = False,
) -> Any:
    """
    Build an Azure SQL condition that matches one property inside a JSON column.

    The column is guarded with ISJSON() and the property is read with
    JSON_VALUE(). Both the stored value and ``value_to_match`` are lower-cased,
    so the comparison is case-insensitive.

    Args:
        json_column (Any): Mapped column attribute; its ``class_.__tablename__``
            and ``key`` are used to qualify the column in the SQL text.
        property_path (str): JSON path of the property to read (bound parameter).
        value_to_match (str): Value the property is compared against.
        partial_match (bool): When True, use ``LIKE`` with the value wrapped in
            ``%`` wildcards (substring match). When False, use ``=`` against the
            bare value (exact match). Defaults to False.

    Returns:
        Any: SQLAlchemy text condition with bound parameters.
    """
    table_name = json_column.class_.__tablename__
    column_name = json_column.key
    column_ref = f'"{table_name}".{column_name}'

    lowered_value = value_to_match.lower()
    if partial_match:
        operator = "LIKE"
        match_property_value = f"%{lowered_value}%"
    else:
        # The wildcards must NOT be added here: with "=" they would be compared
        # literally and an exact match could never succeed.
        operator = "="
        match_property_value = lowered_value

    return text(
        f"ISJSON({column_ref}) = 1\n"
        f"AND LOWER(JSON_VALUE({column_ref}, :property_path)) {operator} :match_property_value"
    ).bindparams(
        property_path=property_path,
        match_property_value=match_property_value,
    )

def _get_condition_json_array_match(
    self,
    *,
    json_column: Any,
    property_path: str,
    sub_path: str | None = None,
    array_to_match: Sequence[str],
) -> Any:
    """
    Build an Azure SQL condition that matches values inside a JSON array.

    When ``array_to_match`` is empty, the condition matches rows where the
    column is NULL, or the array at ``property_path`` is missing or equal to
    ``'[]'``. Otherwise every requested value must be present in the array
    (the per-value EXISTS clauses are joined with AND), compared
    case-insensitively via LOWER().

    Args:
        json_column (Any): Mapped column attribute; its ``class_.__tablename__``
            and ``key`` are used to qualify the column in the SQL text.
        property_path (str): JSON path of the array inside the column (bound parameter).
        sub_path (str | None): Optional JSON path applied to each array element
            before comparison. When falsy, the element value itself is compared.
        array_to_match (Sequence[str]): Values that must all appear in the array.

    Returns:
        Any: SQLAlchemy text condition with bound parameters.
    """
    table_name = json_column.class_.__tablename__
    column_name = json_column.key
    column_ref = f'"{table_name}".{column_name}'

    if not array_to_match:
        return text(
            f"({column_ref} IS NULL\n"
            f"OR JSON_QUERY({column_ref}, :property_path) IS NULL\n"
            f"OR JSON_QUERY({column_ref}, :property_path) = '[]')"
        ).bindparams(property_path=property_path)

    bindparams_dict: dict[str, str] = {"property_path": property_path}
    if sub_path:
        # Bind the element path instead of splicing it into the SQL text, the
        # same way property_path is handled; this avoids SQL injection through
        # a caller-supplied path.
        value_expression = "LOWER(JSON_VALUE(value, :sub_path))"
        bindparams_dict["sub_path"] = sub_path
    else:
        value_expression = "LOWER(value)"

    conditions = []
    for index, match_value in enumerate(array_to_match):
        param_name = f"match_value_{index}"
        conditions.append(
            f"EXISTS(SELECT 1 FROM OPENJSON(JSON_QUERY({column_ref},\n"
            ":property_path))\n"
            f"WHERE {value_expression} = :{param_name})"
        )
        bindparams_dict[param_name] = match_value.lower()

    combined = " AND ".join(conditions)
    return text(f"ISJSON({column_ref}) = 1 AND {combined}").bindparams(**bindparams_dict)

def _get_unique_json_array_values(
    self,
    *,
    json_column: Any,
    path_to_array: str,
    sub_path: str | None = None,
) -> list[str]:
    """
    Collect the distinct, non-NULL values found at a JSON path in an Azure SQL column.

    Without ``sub_path`` the value at ``path_to_array`` is read directly with
    JSON_VALUE(). With ``sub_path`` the array at ``path_to_array`` is expanded
    with CROSS APPLY OPENJSON() and ``sub_path`` is read from each element.
    Rows whose column is not valid JSON are ignored.

    Args:
        json_column (Any): Mapped column attribute; its ``class_.__tablename__``
            and ``key`` identify the table and column to query.
        path_to_array (str): JSON path inside the column (bound parameter).
        sub_path (str | None): Optional JSON path evaluated on each array element.

    Returns:
        list[str]: The unique values, sorted.
    """
    table_name = json_column.class_.__tablename__
    column_name = json_column.key
    column_ref = f'"{table_name}".{column_name}'

    if sub_path is not None:
        statement = text(
            "SELECT DISTINCT JSON_VALUE(items.value, :sub_path) AS value\n"
            f'FROM "{table_name}"\n'
            f"CROSS APPLY OPENJSON(JSON_QUERY({column_ref}, :path_to_array)) AS items\n"
            f"WHERE ISJSON({column_ref}) = 1\n"
            "AND JSON_VALUE(items.value, :sub_path) IS NOT NULL"
        ).bindparams(
            path_to_array=path_to_array,
            sub_path=sub_path,
        )
    else:
        statement = text(
            f"SELECT DISTINCT JSON_VALUE({column_ref}, :path_to_array) AS value\n"
            f'FROM "{table_name}"\n'
            f"WHERE ISJSON({column_ref}) = 1\n"
            f"AND JSON_VALUE({column_ref}, :path_to_array) IS NOT NULL"
        ).bindparams(path_to_array=path_to_array)

    with closing(self.get_session()) as session:
        unique_values = [row[0] for row in session.execute(statement).fetchall()]
    unique_values.sort()
    return unique_values

def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any:
"""
Get the SQL Azure implementation for filtering AttackResults by targeted harm categories.
Expand Down Expand Up @@ -388,110 +459,6 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any:
)
)

def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any:
    """
    Build the Azure SQL predicate that filters AttackResults by attack class.

    Reads ``$.children.attack.class_name`` from the ``atomic_attack_identifier``
    JSON column with JSON_VALUE() and compares it for equality with the bound value.

    Args:
        attack_class (str): Exact attack class name to match.

    Returns:
        Any: SQLAlchemy text condition with bound parameter.
    """
    clause = text(
        'ISJSON("AttackResultEntries".atomic_attack_identifier) = 1\n'
        'AND JSON_VALUE("AttackResultEntries".atomic_attack_identifier,\n'
        "'$.children.attack.class_name') = :attack_class"
    )
    return clause.bindparams(attack_class=attack_class)

def _get_attack_result_converter_classes_condition(self, *, converter_classes: Sequence[str]) -> Any:
    """
    Build the Azure SQL predicate that filters AttackResults by converter classes.

    Works on the ``atomic_attack_identifier`` JSON column using
    JSON_VALUE()/JSON_QUERY()/OPENJSON().

    An empty ``converter_classes`` selects attacks that have no converters.
    A non-empty one requires every listed class to be present in the stored
    converter list (AND logic, case-insensitive).

    Args:
        converter_classes (Sequence[str]): Converter class names. Empty means "no converters".

    Returns:
        Any: SQLAlchemy combined condition with bound parameters.
    """
    if len(converter_classes) == 0:
        # "No converters" covers three stored shapes: the identifier is NULL,
        # the converter list is missing, or the list is an empty array.
        return text(
            '("AttackResultEntries".atomic_attack_identifier IS NULL\n'
            'OR JSON_QUERY("AttackResultEntries".atomic_attack_identifier,\n'
            "'$.children.attack.children.request_converters') IS NULL\n"
            'OR JSON_QUERY("AttackResultEntries".atomic_attack_identifier,\n'
            "'$.children.attack.children.request_converters') = '[]')"
        )

    # One bound parameter per requested class, holding its lower-cased name.
    bound_values: dict[str, str] = {
        f"conv_cls_{position}": class_name.lower() for position, class_name in enumerate(converter_classes)
    }
    exists_clauses = [
        'EXISTS(SELECT 1 FROM OPENJSON(JSON_QUERY("AttackResultEntries".atomic_attack_identifier,\n'
        "'$.children.attack.children.request_converters'))\n"
        f"WHERE LOWER(JSON_VALUE(value, '$.class_name')) = :{param_name})"
        for param_name in bound_values
    ]

    sql = 'ISJSON("AttackResultEntries".atomic_attack_identifier) = 1 AND ' + " AND ".join(exists_clauses)
    return text(sql).bindparams(**bound_values)

def get_unique_attack_class_names(self) -> list[str]:
    """
    Azure SQL implementation: list the distinct attack class names stored in
    the atomic_attack_identifier JSON column of AttackResultEntries.

    Rows whose column is not valid JSON, or whose
    ``$.children.attack.class_name`` is NULL, are left out.

    Returns:
        Sorted list of unique attack class name strings.
    """
    query = text(
        "SELECT DISTINCT JSON_VALUE(atomic_attack_identifier,\n"
        "'$.children.attack.class_name') AS cls\n"
        'FROM "AttackResultEntries"\n'
        "WHERE ISJSON(atomic_attack_identifier) = 1\n"
        "AND JSON_VALUE(atomic_attack_identifier,\n"
        "'$.children.attack.class_name') IS NOT NULL"
    )
    with closing(self.get_session()) as session:
        class_names = [row[0] for row in session.execute(query).fetchall()]
    class_names.sort()
    return class_names

def get_unique_converter_class_names(self) -> list[str]:
    """
    Azure SQL implementation: list the distinct converter class names found in
    the children.attack.children.request_converters array of the
    atomic_attack_identifier JSON column.

    The array is expanded with CROSS APPLY OPENJSON(); elements whose
    ``$.class_name`` is NULL are left out.

    Returns:
        Sorted list of unique converter class name strings.
    """
    query = text(
        "SELECT DISTINCT JSON_VALUE(c.value, '$.class_name') AS cls\n"
        'FROM "AttackResultEntries"\n'
        "CROSS APPLY OPENJSON(JSON_QUERY(atomic_attack_identifier,\n"
        "'$.children.attack.children.request_converters')) AS c\n"
        "WHERE ISJSON(atomic_attack_identifier) = 1\n"
        "AND JSON_VALUE(c.value, '$.class_name') IS NOT NULL"
    )
    with closing(self.get_session()) as session:
        class_names = [row[0] for row in session.execute(query).fetchall()]
    class_names.sort()
    return class_names

def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str, ConversationStats]:
"""
Azure SQL implementation: lightweight aggregate stats per conversation.
Expand Down Expand Up @@ -593,46 +560,33 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any
conditions.append(condition)
return and_(*conditions)

def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> TextClause:
def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None:
"""
Get the SQL Azure implementation for filtering ScenarioResults by target endpoint.

Uses JSON_VALUE() function specific to SQL Azure.

Args:
endpoint (str): The endpoint URL substring to filter by (case-insensitive).
Insert a list of message pieces into the memory storage.

Returns:
Any: SQLAlchemy text condition with bound parameter.
"""
return text(
"""ISJSON(objective_target_identifier) = 1
AND LOWER(JSON_VALUE(objective_target_identifier, '$.endpoint')) LIKE :endpoint"""
).bindparams(endpoint=f"%{endpoint.lower()}%")
self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in message_pieces])

def _get_scenario_result_target_model_condition(self, *, model_name: str) -> TextClause:
def get_unique_attack_class_names(self) -> list[str]:
"""
Get the SQL Azure implementation for filtering ScenarioResults by target model name.

Uses JSON_VALUE() function specific to SQL Azure.

Args:
model_name (str): The model name substring to filter by (case-insensitive).
Azure SQL implementation: extract unique class_name values from
the atomic_attack_identifier JSON column.

Returns:
Any: SQLAlchemy text condition with bound parameter.
Sorted list of unique attack class name strings.
"""
return text(
"""ISJSON(objective_target_identifier) = 1
AND LOWER(JSON_VALUE(objective_target_identifier, '$.model_name')) LIKE :model_name"""
).bindparams(model_name=f"%{model_name.lower()}%")
return super().get_unique_attack_class_names()

def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None:
def get_unique_converter_class_names(self) -> list[str]:
"""
Insert a list of message pieces into the memory storage.
Azure SQL implementation: extract unique converter class_name values
from the children.attack.children.request_converters array
in the atomic_attack_identifier JSON column.

Returns:
Sorted list of unique converter class name strings.
"""
self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in message_pieces])
return super().get_unique_converter_class_names()

def dispose_engine(self) -> None:
"""
Expand Down
13 changes: 13 additions & 0 deletions pyrit/memory/identifier_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from dataclasses import dataclass


@dataclass(frozen=True)
class IdentifierFilter:
    """
    Immutable filter definition for matching JSON-backed identifier properties.

    The dataclass is frozen: instances compare by value, are hashable, and
    reject attribute assignment after construction.

    Attributes:
        property_path: Path of the property to inspect within the identifier JSON.
        value_to_match: Value the property is compared against.
        partial_match: Whether a partial match is requested rather than an
            exact one. Defaults to False.
    """

    property_path: str
    value_to_match: str
    partial_match: bool = False
Loading
Loading