From 6ec9e809e7188af4956c48cb4bc92a0718da08b4 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 1 Apr 2026 13:47:31 -0700 Subject: [PATCH 1/9] Introduce IdentifierFilters to allow generic DB queries on identifier properties --- pyrit/memory/__init__.py | 9 + pyrit/memory/azure_sql_memory.py | 247 +++++++---------- pyrit/memory/identifier_filters.py | 95 +++++++ pyrit/memory/memory_interface.py | 262 +++++++++++++----- pyrit/memory/sqlite_memory.py | 202 ++++++-------- .../test_interface_attack_results.py | 53 ++++ .../test_interface_prompts.py | 115 ++++++++ .../test_interface_scenario_results.py | 114 ++++++++ .../memory_interface/test_interface_scores.py | 74 +++++ 9 files changed, 822 insertions(+), 349 deletions(-) create mode 100644 pyrit/memory/identifier_filters.py diff --git a/pyrit/memory/__init__.py b/pyrit/memory/__init__.py index 9f10860130..102a1f8607 100644 --- a/pyrit/memory/__init__.py +++ b/pyrit/memory/__init__.py @@ -7,6 +7,7 @@ This package defines the core `MemoryInterface` and concrete implementations for different storage backends. 
""" +from pyrit.memory.identifier_filters import AttackIdentifierFilter, AttackIdentifierProperty, ConverterIdentifierFilter, ConverterIdentifierProperty, ScorerIdentifierFilter, ScorerIdentifierProperty, TargetIdentifierFilter, TargetIdentifierProperty from pyrit.memory.azure_sql_memory import AzureSQLMemory from pyrit.memory.central_memory import CentralMemory from pyrit.memory.memory_embedding import MemoryEmbedding @@ -17,6 +18,10 @@ __all__ = [ "AttackResultEntry", + "AttackIdentifierFilter", + "AttackIdentifierProperty", + "ConverterIdentifierFilter", + "ConverterIdentifierProperty", "AzureSQLMemory", "CentralMemory", "SQLiteMemory", @@ -25,5 +30,9 @@ "MemoryEmbedding", "MemoryExporter", "PromptMemoryEntry", + "ScorerIdentifierFilter", + "ScorerIdentifierProperty", "SeedEntry", + "TargetIdentifierFilter", + "TargetIdentifierProperty", ] diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index c9f349c0d9..48ae2c5df2 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -250,22 +250,6 @@ def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str condition = text(conditions).bindparams(**{key: str(value) for key, value in memory_labels.items()}) return [condition] - def _get_message_pieces_attack_conditions(self, *, attack_id: str) -> Any: - """ - Generate SQL condition for filtering message pieces by attack ID. - - Uses JSON_VALUE() function specific to SQL Azure to query the attack identifier. - - Args: - attack_id (str): The attack identifier to filter by. - - Returns: - Any: SQLAlchemy text condition with bound parameter. - """ - return text("ISJSON(attack_identifier) = 1 AND JSON_VALUE(attack_identifier, '$.hash') = :json_id").bindparams( - json_id=str(attack_id) - ) - def _get_metadata_conditions(self, *, prompt_metadata: dict[str, Union[str, int]]) -> list[TextClause]: """ Generate SQL conditions for filtering by prompt metadata. 
@@ -321,6 +305,99 @@ def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) """ return self._get_metadata_conditions(prompt_metadata=metadata)[0] + def _get_condition_json_property_match( + self, + *, + json_column: Any, + property_path: str, + value_to_match: str, + partial_match: bool = False, + ) -> Any: + table_name = json_column.class_.__tablename__ + column_name = json_column.key + + return text( + f"""ISJSON("{table_name}".{column_name}) = 1 + AND LOWER(JSON_VALUE("{table_name}".{column_name}, :property_path)) {"LIKE" if partial_match else "="} :match_property_value""" # noqa: E501 + ).bindparams( + property_path=property_path, + match_property_value=f"%{value_to_match.lower()}%" if partial_match else value_to_match, + ) + # The above return statement already handles both partial and exact matches + # The following code is now unreachable and can be removed + + def _get_condition_json_array_match( + self, + *, + json_column: Any, + property_path: str, + array_to_match: Sequence[str], + case_insensitive: bool = False, + ) -> Any: + table_name = json_column.class_.__tablename__ + column_name = json_column.key + if len(array_to_match) == 0: + return text( + f"""("{table_name}".{column_name} IS NULL + OR JSON_QUERY("{table_name}".{column_name}, :property_path) IS NULL + OR JSON_QUERY("{table_name}".{column_name}, :property_path) = '[]')""" + ).bindparams(property_path=property_path) + + value_expression = "JSON_VALUE(value, '$.class_name')" + if case_insensitive: + value_expression = f"LOWER({value_expression})" + + conditions = [] + bindparams_dict: dict[str, str] = {"property_path": property_path} + + for index, match_value in enumerate(array_to_match): + param_name = f"match_value_{index}" + conditions.append( + f"""EXISTS(SELECT 1 FROM OPENJSON(JSON_QUERY("{table_name}".{column_name}, + :property_path)) + WHERE {value_expression} = :{param_name})""" + ) + bindparams_dict[param_name] = match_value.lower() if case_insensitive else 
match_value + + combined = " AND ".join(conditions) + return text(f"""ISJSON("{table_name}".{column_name}) = 1 AND {combined}""").bindparams( + **bindparams_dict + ) + + def _get_unique_json_property_values( + self, + *, + json_column: Any, + path_to_array: str, + sub_path: str | None = None, + ) -> list[str]: + table_name = json_column.class_.__tablename__ + column_name = json_column.key + with closing(self.get_session()) as session: + if sub_path is None: + rows = session.execute( + text( + f"""SELECT DISTINCT JSON_VALUE("{table_name}".{column_name}, :path_to_array) AS value + FROM "{table_name}" + WHERE ISJSON("{table_name}".{column_name}) = 1 + AND JSON_VALUE("{table_name}".{column_name}, :path_to_array) IS NOT NULL""" + ).bindparams(path_to_array=path_to_array) + ).fetchall() + else: + rows = session.execute( + text( + f"""SELECT DISTINCT JSON_VALUE(items.value, :sub_path) AS value + FROM "{table_name}" + CROSS APPLY OPENJSON(JSON_QUERY("{table_name}".{column_name}, :path_to_array)) AS items + WHERE ISJSON("{table_name}".{column_name}) = 1 + AND JSON_VALUE(items.value, :sub_path) IS NOT NULL""" + ).bindparams( + path_to_array=path_to_array, + sub_path=sub_path, + ) + ).fetchall() + return sorted(row[0] for row in rows) + def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any: """ Get the SQL Azure implementation for filtering AttackResults by targeted harm categories. @@ -388,110 +465,6 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: ) ) - def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any: - """ - Azure SQL implementation for filtering AttackResults by attack class. - Uses JSON_VALUE() on the atomic_attack_identifier JSON column. - - Args: - attack_class (str): Exact attack class name to match. - - Returns: - Any: SQLAlchemy text condition with bound parameter. 
- """ - return text( - """ISJSON("AttackResultEntries".atomic_attack_identifier) = 1 - AND JSON_VALUE("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.class_name') = :attack_class""" - ).bindparams(attack_class=attack_class) - - def _get_attack_result_converter_classes_condition(self, *, converter_classes: Sequence[str]) -> Any: - """ - Azure SQL implementation for filtering AttackResults by converter classes. - - Uses JSON_VALUE()/JSON_QUERY()/OPENJSON() on the atomic_attack_identifier - JSON column. - - When converter_classes is empty, matches attacks with no converters. - When non-empty, uses OPENJSON() to check all specified classes are present - (AND logic, case-insensitive). - - Args: - converter_classes (Sequence[str]): List of converter class names. Empty list means no converters. - - Returns: - Any: SQLAlchemy combined condition with bound parameters. - """ - if len(converter_classes) == 0: - # Explicitly "no converters": match attacks where the converter list - # is absent, null, or empty in the stored JSON. 
- return text( - """("AttackResultEntries".atomic_attack_identifier IS NULL - OR JSON_QUERY("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.children.request_converters') IS NULL - OR JSON_QUERY("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.children.request_converters') = '[]')""" - ) - - conditions = [] - bindparams_dict: dict[str, str] = {} - for i, cls in enumerate(converter_classes): - param_name = f"conv_cls_{i}" - conditions.append( - f"""EXISTS(SELECT 1 FROM OPENJSON(JSON_QUERY("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.children.request_converters')) - WHERE LOWER(JSON_VALUE(value, '$.class_name')) = :{param_name})""" - ) - bindparams_dict[param_name] = cls.lower() - - combined = " AND ".join(conditions) - return text(f"""ISJSON("AttackResultEntries".atomic_attack_identifier) = 1 AND {combined}""").bindparams( - **bindparams_dict - ) - - def get_unique_attack_class_names(self) -> list[str]: - """ - Azure SQL implementation: extract unique class_name values from - the atomic_attack_identifier JSON column. - - Returns: - Sorted list of unique attack class name strings. - """ - with closing(self.get_session()) as session: - rows = session.execute( - text( - """SELECT DISTINCT JSON_VALUE(atomic_attack_identifier, - '$.children.attack.class_name') AS cls - FROM "AttackResultEntries" - WHERE ISJSON(atomic_attack_identifier) = 1 - AND JSON_VALUE(atomic_attack_identifier, - '$.children.attack.class_name') IS NOT NULL""" - ) - ).fetchall() - return sorted(row[0] for row in rows) - - def get_unique_converter_class_names(self) -> list[str]: - """ - Azure SQL implementation: extract unique converter class_name values - from the children.attack.children.request_converters array - in the atomic_attack_identifier JSON column. - - Returns: - Sorted list of unique converter class name strings. 
- """ - with closing(self.get_session()) as session: - rows = session.execute( - text( - """SELECT DISTINCT JSON_VALUE(c.value, '$.class_name') AS cls - FROM "AttackResultEntries" - CROSS APPLY OPENJSON(JSON_QUERY(atomic_attack_identifier, - '$.children.attack.children.request_converters')) AS c - WHERE ISJSON(atomic_attack_identifier) = 1 - AND JSON_VALUE(c.value, '$.class_name') IS NOT NULL""" - ) - ).fetchall() - return sorted(row[0] for row in rows) - def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str, ConversationStats]: """ Azure SQL implementation: lightweight aggregate stats per conversation. @@ -593,40 +566,6 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any conditions.append(condition) return and_(*conditions) - def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> TextClause: - """ - Get the SQL Azure implementation for filtering ScenarioResults by target endpoint. - - Uses JSON_VALUE() function specific to SQL Azure. - - Args: - endpoint (str): The endpoint URL substring to filter by (case-insensitive). - - Returns: - Any: SQLAlchemy text condition with bound parameter. - """ - return text( - """ISJSON(objective_target_identifier) = 1 - AND LOWER(JSON_VALUE(objective_target_identifier, '$.endpoint')) LIKE :endpoint""" - ).bindparams(endpoint=f"%{endpoint.lower()}%") - - def _get_scenario_result_target_model_condition(self, *, model_name: str) -> TextClause: - """ - Get the SQL Azure implementation for filtering ScenarioResults by target model name. - - Uses JSON_VALUE() function specific to SQL Azure. - - Args: - model_name (str): The model name substring to filter by (case-insensitive). - - Returns: - Any: SQLAlchemy text condition with bound parameter. 
- """ - return text( - """ISJSON(objective_target_identifier) = 1 - AND LOWER(JSON_VALUE(objective_target_identifier, '$.model_name')) LIKE :model_name""" - ).bindparams(model_name=f"%{model_name.lower()}%") - def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: """ Insert a list of message pieces into the memory storage. diff --git a/pyrit/memory/identifier_filters.py b/pyrit/memory/identifier_filters.py new file mode 100644 index 0000000000..8792f03241 --- /dev/null +++ b/pyrit/memory/identifier_filters.py @@ -0,0 +1,95 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from abc import ABC +from dataclasses import dataclass +from enum import Enum +from typing import Generic, TypeVar + + +# TODO: if/when we move to python 3.11+, we can replace this with StrEnum +class _StrEnum(str, Enum): + """Base class that mimics StrEnum behavior for Python < 3.11.""" + + def __str__(self) -> str: + return self.value + + +T = TypeVar("T", bound=_StrEnum) + + +class IdentifierProperty(_StrEnum): + """Allowed JSON paths for identifier filtering.""" + + HASH = "$.hash" + CLASS_NAME = "$.class_name" + + +@dataclass(frozen=True) +class IdentifierFilter(ABC, Generic[T]): + """Immutable filter definition for matching JSON-backed identifier properties.""" + + property_path: T | str + value_to_match: str + partial_match: bool = False + + def __post_init__(self) -> None: + """Normalize and validate the configured property path.""" + object.__setattr__(self, "property_path", str(self.property_path)) + + +class AttackIdentifierProperty(_StrEnum): + """Allowed JSON paths for attack identifier filtering.""" + + HASH = "$.hash" + ATTACK_CLASS_NAME = "$.children.attack.class_name" + REQUEST_CONVERTERS = "$.children.attack.children.request_converters" + + +class TargetIdentifierProperty(_StrEnum): + """Allowed JSON paths for target identifier filtering.""" + + HASH = "$.hash" + ENDPOINT = "$.endpoint" + MODEL_NAME = 
"$.model_name" + + +class ConverterIdentifierProperty(_StrEnum): + """Allowed JSON paths for converter identifier filtering.""" + + HASH = "$.hash" + CLASS_NAME = "$.class_name" + + +class ScorerIdentifierProperty(_StrEnum): + """Allowed JSON paths for scorer identifier filtering.""" + + HASH = "$.hash" + CLASS_NAME = "$.class_name" + + +@dataclass(frozen=True) +class AttackIdentifierFilter(IdentifierFilter[AttackIdentifierProperty]): + """ + Immutable filter definition for matching JSON-backed attack identifier properties. + + Args: + property_path: The JSON path of the property to filter on. + value_to_match: The value to match against the property. + partial_match: Whether to allow partial matches (default: False). + """ + + +@dataclass(frozen=True) +class TargetIdentifierFilter(IdentifierFilter[TargetIdentifierProperty]): + """Immutable filter definition for matching JSON-backed target identifier properties.""" + + +@dataclass(frozen=True) +class ConverterIdentifierFilter(IdentifierFilter[ConverterIdentifierProperty]): + """Immutable filter definition for matching JSON-backed converter identifier properties.""" + + +@dataclass(frozen=True) +class ScorerIdentifierFilter(IdentifierFilter[ScorerIdentifierProperty]): + """Immutable filter definition for matching JSON-backed scorer identifier properties.""" diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 90322ebec4..5bc1f4ad3e 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -19,6 +19,14 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute from pyrit.common.path import DB_DATA_PATH +from pyrit.memory.identifier_filters import ( + AttackIdentifierFilter, + AttackIdentifierProperty, + ConverterIdentifierProperty, + ScorerIdentifierFilter, + TargetIdentifierFilter, + TargetIdentifierProperty, +) from pyrit.memory.memory_embedding import ( MemoryEmbedding, default_memory_embedding_factory, @@ -113,6 +121,77 @@ def 
disable_embedding(self) -> None: """ self.memory_embedding = None + @abc.abstractmethod + def _get_condition_json_property_match( + self, + *, + json_column: InstrumentedAttribute[Any], + property_path: str, + value_to_match: str, + partial_match: bool = False, + ) -> Any: + """ + Return a database-specific condition for matching a value at a given path within a JSON object. + + Args: + json_column (InstrumentedAttribute[Any]): The JSON-backed model field to query. + property_path (str): The JSON path for the property to match. + value_to_match (str): The string value that must match the extracted JSON property value. + partial_match (bool): Whether to perform a case-insensitive substring match. + + Returns: + Any: A SQLAlchemy condition for the backend-specific JSON query. + """ + + @abc.abstractmethod + def _get_condition_json_array_match( + self, + *, + json_column: InstrumentedAttribute[Any], + property_path: str, + array_to_match: Sequence[str], + case_insensitive: bool = False, + ) -> Any: + """ + Return a database-specific condition for matching an array at a given path within a JSON object. + + Args: + json_column (InstrumentedAttribute[Any]): The JSON-backed SQLAlchemy field to query. + property_path (str): The JSON path for the target array. + array_to_match (Sequence[str]): The array that must match the extracted JSON array values. + For a match, ALL values in this array must be present in the JSON array. + If `array_to_match` is empty, the condition must match only if the target is also an empty array or None. + case_insensitive (bool): Whether string comparison should ignore casing. + + Returns: + Any: A database-specific SQLAlchemy condition. + """ + + @abc.abstractmethod + def _get_unique_json_array_values( + self, + *, + json_column: Any, + path_to_array: str, + sub_path: str | None = None, + ) -> list[str]: + """ + Return sorted unique values in an array located at a given path within a JSON object. 
+        This method performs a database-level query to extract distinct values from
+        an array within a JSON-type column. When ``sub_path`` is provided, the distinct values are
- - This method is only called when converter filtering is requested (converter_classes - is not None). The caller handles the None-vs-list distinction: - - - ``len(converter_classes) == 0``: return a condition matching attacks with NO converters. - - ``len(converter_classes) > 0``: return a condition requiring ALL specified converter - class names to be present (AND logic, case-insensitive). - - Args: - converter_classes: Converter class names to require. An empty sequence means - "match only attacks that have no converters". - - Returns: - Database-specific SQLAlchemy condition. - """ - - @abc.abstractmethod def get_unique_attack_class_names(self) -> list[str]: """ Return sorted unique attack class names from all stored attack results. @@ -334,8 +372,11 @@ def get_unique_attack_class_names(self) -> list[str]: Returns: Sorted list of unique attack class name strings. """ + return self._get_unique_json_array_values( + json_column=AttackResultEntry.atomic_attack_identifier, + path_to_array=AttackIdentifierProperty.ATTACK_CLASS_NAME + ) - @abc.abstractmethod def get_unique_converter_class_names(self) -> list[str]: """ Return sorted unique converter class names used across all attack results. @@ -346,6 +387,11 @@ def get_unique_converter_class_names(self) -> list[str]: Returns: Sorted list of unique converter class name strings. """ + return self._get_unique_json_array_values( + json_column=AttackResultEntry.atomic_attack_identifier, + path_to_array=AttackIdentifierProperty.REQUEST_CONVERTERS, + sub_path=ConverterIdentifierProperty.CLASS_NAME, + ) @abc.abstractmethod def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str, "ConversationStats"]: @@ -377,30 +423,6 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any Database-specific SQLAlchemy condition. 
""" - @abc.abstractmethod - def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> Any: - """ - Return a database-specific condition for filtering ScenarioResults by target endpoint. - - Args: - endpoint: Endpoint substring to search for (case-insensitive). - - Returns: - Database-specific SQLAlchemy condition. - """ - - @abc.abstractmethod - def _get_scenario_result_target_model_condition(self, *, model_name: str) -> Any: - """ - Return a database-specific condition for filtering ScenarioResults by target model name. - - Args: - model_name: Model name substring to search for (case-insensitive). - - Returns: - Database-specific SQLAlchemy condition. - """ - def add_scores_to_memory(self, *, scores: Sequence[Score]) -> None: """ Insert a list of scores into the memory storage. @@ -425,6 +447,7 @@ def get_scores( score_category: Optional[str] = None, sent_after: Optional[datetime] = None, sent_before: Optional[datetime] = None, + scorer_identifier_filter: Optional[ScorerIdentifierFilter] = None, ) -> Sequence[Score]: """ Retrieve a list of Score objects based on the specified filters. @@ -435,6 +458,8 @@ def get_scores( score_category (Optional[str]): The category of the score to filter by. sent_after (Optional[datetime]): Filter for scores sent after this datetime. sent_before (Optional[datetime]): Filter for scores sent before this datetime. + scorer_identifier_filter (Optional[ScorerIdentifierFilter]): A ScorerIdentifierFilter object that + allows filtering by various scorer identifier JSON properties. Defaults to None. Returns: Sequence[Score]: A list of Score objects that match the specified filters. 
@@ -451,6 +476,15 @@ def get_scores( conditions.append(ScoreEntry.timestamp >= sent_after) if sent_before: conditions.append(ScoreEntry.timestamp <= sent_before) + if scorer_identifier_filter: + conditions.append( + self._get_condition_json_property_match( + json_column=ScoreEntry.scorer_class_identifier, + property_path=scorer_identifier_filter.property_path, + value_to_match=scorer_identifier_filter.value_to_match, + partial_match=scorer_identifier_filter.partial_match, + ) + ) if not conditions: return [] @@ -581,6 +615,8 @@ def get_message_pieces( data_type: Optional[str] = None, not_data_type: Optional[str] = None, converted_value_sha256: Optional[Sequence[str]] = None, + attack_identifier_filter: Optional[AttackIdentifierFilter] = None, + prompt_target_identifier_filter: Optional[TargetIdentifierFilter] = None, ) -> Sequence[MessagePiece]: """ Retrieve a list of MessagePiece objects based on the specified filters. @@ -602,6 +638,12 @@ def get_message_pieces( not_data_type (Optional[str], optional): The data type to exclude. Defaults to None. converted_value_sha256 (Optional[Sequence[str]], optional): A list of SHA256 hashes of converted values. Defaults to None. + attack_identifier_filter (Optional[AttackIdentifierFilter], optional): + An AttackIdentifierFilter object that + allows filtering by various attack identifier JSON properties. Defaults to None. + prompt_target_identifier_filter (Optional[TargetIdentifierFilter], optional): + A TargetIdentifierFilter object that + allows filtering by various target identifier JSON properties. Defaults to None. Returns: Sequence[MessagePiece]: A list of MessagePiece objects that match the specified filters. 
@@ -612,7 +654,13 @@ def get_message_pieces( """ conditions = [] if attack_id: - conditions.append(self._get_message_pieces_attack_conditions(attack_id=str(attack_id))) + conditions.append( + self._get_condition_json_property_match( + json_column=PromptMemoryEntry.attack_identifier, + property_path=AttackIdentifierProperty.HASH, + value_to_match=str(attack_id) + ) + ) if role: conditions.append(PromptMemoryEntry.role == role) if conversation_id: @@ -638,6 +686,24 @@ def get_message_pieces( conditions.append(PromptMemoryEntry.converted_value_data_type != not_data_type) if converted_value_sha256: conditions.append(PromptMemoryEntry.converted_value_sha256.in_(converted_value_sha256)) + if attack_identifier_filter: + conditions.append( + self._get_condition_json_property_match( + json_column=PromptMemoryEntry.attack_identifier, + property_path=attack_identifier_filter.property_path, + value_to_match=attack_identifier_filter.value_to_match, + partial_match=attack_identifier_filter.partial_match, + ) + ) + if prompt_target_identifier_filter: + conditions.append( + self._get_condition_json_property_match( + json_column=PromptMemoryEntry.prompt_target_identifier, + property_path=prompt_target_identifier_filter.property_path, + value_to_match=prompt_target_identifier_filter.value_to_match, + partial_match=prompt_target_identifier_filter.partial_match, + ) + ) try: memory_entries: Sequence[PromptMemoryEntry] = self._query_entries( @@ -1365,6 +1431,7 @@ def get_attack_results( converter_classes: Optional[Sequence[str]] = None, targeted_harm_categories: Optional[Sequence[str]] = None, labels: Optional[dict[str, str]] = None, + attack_identifier_filter: Optional[AttackIdentifierFilter] = None, ) -> Sequence[AttackResult]: """ Retrieve a list of AttackResult objects based on the specified filters. @@ -1392,6 +1459,9 @@ def get_attack_results( labels (Optional[dict[str, str]], optional): A dictionary of memory labels to filter results by. 
These labels are associated with the prompts themselves, used for custom tagging and tracking. Defaults to None. + attack_identifier_filter (Optional[AttackIdentifierFilter], optional): + An AttackIdentifierFilter object that allows filtering by various attack identifier + JSON properties. Defaults to None. Returns: Sequence[AttackResult]: A list of AttackResult objects that match the specified filters. @@ -1415,12 +1485,25 @@ def get_attack_results( if attack_class: # Use database-specific JSON query method - conditions.append(self._get_attack_result_attack_class_condition(attack_class=attack_class)) + conditions.append( + self._get_condition_json_property_match( + json_column=AttackResultEntry.atomic_attack_identifier, + property_path=AttackIdentifierProperty.ATTACK_CLASS_NAME, + value_to_match=attack_class, + ) + ) if converter_classes is not None: # converter_classes=[] means "only attacks with no converters" # converter_classes=["A","B"] means "must have all listed converters" - conditions.append(self._get_attack_result_converter_classes_condition(converter_classes=converter_classes)) + conditions.append( + self._get_condition_json_array_match( + json_column=AttackResultEntry.atomic_attack_identifier, + property_path=AttackIdentifierProperty.REQUEST_CONVERTERS, + array_to_match=converter_classes, + case_insensitive=True, + ) + ) if targeted_harm_categories: # Use database-specific JSON query method @@ -1432,6 +1515,16 @@ def get_attack_results( # Use database-specific JSON query method conditions.append(self._get_attack_result_label_condition(labels=labels)) + if attack_identifier_filter: + conditions.append( + self._get_condition_json_property_match( + json_column=AttackResultEntry.atomic_attack_identifier, + property_path=attack_identifier_filter.property_path, + value_to_match=attack_identifier_filter.value_to_match, + partial_match=attack_identifier_filter.partial_match, + ) + ) + try: entries: Sequence[AttackResultEntry] = self._query_entries( 
AttackResultEntry, conditions=and_(*conditions) if conditions else None @@ -1612,6 +1705,7 @@ def get_scenario_results( labels: Optional[dict[str, str]] = None, objective_target_endpoint: Optional[str] = None, objective_target_model_name: Optional[str] = None, + objective_target_identifier_filter: Optional[TargetIdentifierFilter] = None, ) -> Sequence[ScenarioResult]: """ Retrieve a list of ScenarioResult objects based on the specified filters. @@ -1635,6 +1729,8 @@ def get_scenario_results( objective_target_model_name (Optional[str], optional): Filter for scenarios where the objective_target_identifier has a model_name attribute containing this value (case-insensitive). Defaults to None. + objective_target_identifier_filter (Optional[TargetIdentifierFilter], optional): + A TargetIdentifierFilter object that allows filtering by various target identifier JSON properties. Returns: Sequence[ScenarioResult]: A list of ScenarioResult objects that match the specified filters. @@ -1672,11 +1768,35 @@ def get_scenario_results( if objective_target_endpoint: # Use database-specific JSON query method - conditions.append(self._get_scenario_result_target_endpoint_condition(endpoint=objective_target_endpoint)) + conditions.append( + self._get_condition_json_property_match( + json_column=ScenarioResultEntry.objective_target_identifier, + property_path=TargetIdentifierProperty.ENDPOINT, + value_to_match=objective_target_endpoint, + partial_match=True, + ) + ) if objective_target_model_name: # Use database-specific JSON query method - conditions.append(self._get_scenario_result_target_model_condition(model_name=objective_target_model_name)) + conditions.append( + self._get_condition_json_property_match( + json_column=ScenarioResultEntry.objective_target_identifier, + property_path=TargetIdentifierProperty.MODEL_NAME, + value_to_match=objective_target_model_name, + partial_match=True, + ) + ) + + if objective_target_identifier_filter: + conditions.append( + 
self._get_condition_json_property_match( + json_column=ScenarioResultEntry.objective_target_identifier, + property_path=objective_target_identifier_filter.property_path, + value_to_match=objective_target_identifier_filter.value_to_match, + partial_match=objective_target_identifier_filter.partial_match, + ) + ) try: entries: Sequence[ScenarioResultEntry] = self._query_entries( diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 7bd05b4f82..a41dbffc90 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -177,15 +177,6 @@ def _get_message_pieces_prompt_metadata_conditions( condition = text(json_conditions).bindparams(**{key: str(value) for key, value in prompt_metadata.items()}) return [condition] - def _get_message_pieces_attack_conditions(self, *, attack_id: str) -> Any: - """ - Generate SQLAlchemy filter conditions for filtering by attack ID. - - Returns: - Any: A SQLAlchemy text condition with bound parameters. - """ - return text("JSON_EXTRACT(attack_identifier, '$.hash') = :attack_id").bindparams(attack_id=str(attack_id)) - def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) -> Any: """ Generate SQLAlchemy filter conditions for filtering seed prompts by metadata. 
@@ -199,6 +190,84 @@ def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) # Note: We do NOT convert values to string here, to allow integer comparison in JSON return text(json_conditions).bindparams(**dict(metadata.items())) + def _get_condition_json_property_match( + self, + *, + json_column: Any, + property_path: str, + value_to_match: str, + partial_match: bool = False, + ) -> Any: + extracted_value = func.json_extract(json_column, property_path) + if partial_match: + return func.lower(extracted_value).like(f"%{value_to_match.lower()}%") + return extracted_value == value_to_match + + def _get_condition_json_array_match( + self, + *, + json_column: Any, + property_path: str, + array_to_match: Sequence[str], + case_insensitive: bool = False, + ) -> Any: + array_expr = func.json_extract(json_column, property_path) + if len(array_to_match) == 0: + return or_( + json_column.is_(None), + array_expr.is_(None), + array_expr == "[]", + ) + + table_name = json_column.class_.__tablename__ + column_name = json_column.key + value_expression = "json_extract(value, '$.class_name')" + if case_insensitive: + value_expression = f"LOWER({value_expression})" + + conditions = [] + for index, match_value in enumerate(array_to_match): + param_name = f"match_value_{index}" + bind_params: dict[str, str] = { + "property_path": property_path, + param_name: match_value.lower() if case_insensitive else match_value, + } + conditions.append( + text( + f'''EXISTS(SELECT 1 FROM json_each( + json_extract("{table_name}".{column_name}, :property_path)) + WHERE {value_expression} = :{param_name})''' + ).bindparams(**bind_params) + ) + return and_(*conditions) + + def _get_unique_json_array_values( + self, + *, + json_column: Any, + path_to_array: str, + sub_path: str | None = None, + ) -> list[str]: + with closing(self.get_session()) as session: + if sub_path is None: + property_expr = func.json_extract(json_column, path_to_array) + rows = 
session.query(property_expr).filter(property_expr.isnot(None)).distinct().all() + else: + table_name = json_column.class_.__tablename__ + column_name = json_column.key + rows = session.execute( + text( + f'''SELECT DISTINCT json_extract(j.value, :sub_path) AS value + FROM "{table_name}", + json_each(json_extract("{table_name}".{column_name}, :path_to_array)) AS j + WHERE json_extract(j.value, :sub_path) IS NOT NULL''' + ).bindparams( + path_to_array=path_to_array, + sub_path=sub_path, + ) + ).fetchall() + return sorted(row[0] for row in rows) + def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: """ Insert a list of message pieces into the memory storage. @@ -526,97 +595,6 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: ) return labels_subquery # noqa: RET504 - def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any: - """ - SQLite implementation for filtering AttackResults by attack class. - Uses json_extract() on the atomic_attack_identifier JSON column. - - Returns: - Any: A SQLAlchemy condition for filtering by attack class. - """ - return ( - func.json_extract(AttackResultEntry.atomic_attack_identifier, "$.children.attack.class_name") - == attack_class - ) - - def _get_attack_result_converter_classes_condition(self, *, converter_classes: Sequence[str]) -> Any: - """ - SQLite implementation for filtering AttackResults by converter classes. - - Uses json_extract() on the atomic_attack_identifier JSON column. - - When converter_classes is empty, matches attacks with no converters - (children.attack.children.request_converters is absent or null in the JSON). - When non-empty, uses json_each() to check all specified classes are present - (AND logic, case-insensitive). - - Returns: - Any: A SQLAlchemy condition for filtering by converter classes. 
- """ - if len(converter_classes) == 0: - # Explicitly "no converters": match attacks where the converter list - # is absent, null, or empty in the stored JSON. - converter_json = func.json_extract( - AttackResultEntry.atomic_attack_identifier, - "$.children.attack.children.request_converters", - ) - return or_( - AttackResultEntry.atomic_attack_identifier.is_(None), - converter_json.is_(None), - converter_json == "[]", - ) - - conditions = [] - for i, cls in enumerate(converter_classes): - param_name = f"conv_cls_{i}" - conditions.append( - text( - f"""EXISTS(SELECT 1 FROM json_each( - json_extract("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.children.request_converters')) - WHERE LOWER(json_extract(value, '$.class_name')) = :{param_name})""" - ).bindparams(**{param_name: cls.lower()}) - ) - return and_(*conditions) - - def get_unique_attack_class_names(self) -> list[str]: - """ - SQLite implementation: extract unique class_name values from - the atomic_attack_identifier JSON column. - - Returns: - Sorted list of unique attack class name strings. - """ - with closing(self.get_session()) as session: - class_name_expr = func.json_extract( - AttackResultEntry.atomic_attack_identifier, "$.children.attack.class_name" - ) - rows = session.query(class_name_expr).filter(class_name_expr.isnot(None)).distinct().all() - return sorted(row[0] for row in rows) - - def get_unique_converter_class_names(self) -> list[str]: - """ - SQLite implementation: extract unique converter class_name values - from the children.attack.children.request_converters array in the - atomic_attack_identifier JSON column. - - Returns: - Sorted list of unique converter class name strings. 
- """ - with closing(self.get_session()) as session: - rows = session.execute( - text( - """SELECT DISTINCT json_extract(j.value, '$.class_name') AS cls - FROM "AttackResultEntries", - json_each( - json_extract("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.children.request_converters') - ) AS j - WHERE cls IS NOT NULL""" - ) - ).fetchall() - return sorted(row[0] for row in rows) - def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str, ConversationStats]: """ SQLite implementation: lightweight aggregate stats per conversation. @@ -710,27 +688,3 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any return and_( *[func.json_extract(ScenarioResultEntry.labels, f"$.{key}") == value for key, value in labels.items()] ) - - def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> Any: - """ - SQLite implementation for filtering ScenarioResults by target endpoint. - Uses json_extract() function specific to SQLite. - - Returns: - Any: A SQLAlchemy subquery for filtering by target endpoint. - """ - return func.lower(func.json_extract(ScenarioResultEntry.objective_target_identifier, "$.endpoint")).like( - f"%{endpoint.lower()}%" - ) - - def _get_scenario_result_target_model_condition(self, *, model_name: str) -> Any: - """ - SQLite implementation for filtering ScenarioResults by target model name. - Uses json_extract() function specific to SQLite. - - Returns: - Any: A SQLAlchemy subquery for filtering by target model name. 
- """ - return func.lower(func.json_extract(ScenarioResultEntry.objective_target_identifier, "$.model_name")).like( - f"%{model_name.lower()}%" - ) diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index 91367c3a1c..de238952f4 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -9,6 +9,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier from pyrit.memory import MemoryInterface +from pyrit.memory.identifier_filters import AttackIdentifierFilter, AttackIdentifierProperty from pyrit.memory.memory_models import AttackResultEntry from pyrit.models import ( AttackOutcome, @@ -1352,3 +1353,55 @@ def test_get_unique_converter_class_names_skips_no_converters(sqlite_instance: M result = sqlite_instance.get_unique_converter_class_names() assert result == ["Base64Converter"] + + +def test_get_attack_results_by_attack_identifier_filter_hash(sqlite_instance: MemoryInterface): + """Test filtering attack results by AttackIdentifierFilter with hash.""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") + ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2]) + + # Filter by hash of ar1's attack identifier + results = sqlite_instance.get_attack_results( + attack_identifier_filter=AttackIdentifierFilter( + property_path=AttackIdentifierProperty.HASH, + value_to_match=ar1.atomic_attack_identifier.hash, + partial_match=False, + ), + ) + assert len(results) == 1 + assert results[0].conversation_id == "conv_1" + + +def test_get_attack_results_by_attack_identifier_filter_class_name(sqlite_instance: MemoryInterface): + """Test filtering attack results by AttackIdentifierFilter with 
class_name.""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") + ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack") + ar3 = _make_attack_result_with_identifier("conv_3", "CrescendoAttack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) + + # Filter by partial attack class name + results = sqlite_instance.get_attack_results( + attack_identifier_filter=AttackIdentifierFilter( + property_path=AttackIdentifierProperty.ATTACK_CLASS_NAME, + value_to_match="Crescendo", + partial_match=True, + ), + ) + assert len(results) == 2 + assert {r.conversation_id for r in results} == {"conv_1", "conv_3"} + + +def test_get_attack_results_by_attack_identifier_filter_no_match(sqlite_instance: MemoryInterface): + """Test that AttackIdentifierFilter returns empty when nothing matches.""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) + + results = sqlite_instance.get_attack_results( + attack_identifier_filter=AttackIdentifierFilter( + property_path=AttackIdentifierProperty.HASH, + value_to_match="nonexistent_hash", + partial_match=False, + ), + ) + assert len(results) == 0 diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index 67a4292f87..457169b911 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -14,6 +14,12 @@ from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface, PromptMemoryEntry +from pyrit.memory.identifier_filters import ( + AttackIdentifierFilter, + AttackIdentifierProperty, + TargetIdentifierFilter, + TargetIdentifierProperty, +) from pyrit.models import ( Message, MessagePiece, @@ -1248,3 +1254,112 @@ def 
test_get_request_from_response_raises_error_for_sequence_less_than_one(sqlit with pytest.raises(ValueError, match="The provided request does not have a preceding request \\(sequence < 1\\)."): sqlite_instance.get_request_from_response(response=response_without_request) + + +def test_get_message_pieces_by_attack_identifier_filter(sqlite_instance: MemoryInterface): + attack1 = PromptSendingAttack(objective_target=get_mock_target()) + attack2 = PromptSendingAttack(objective_target=get_mock_target("Target2")) + + entries = [ + PromptMemoryEntry( + entry=MessagePiece( + role="user", + original_value="Hello 1", + attack_identifier=attack1.get_identifier(), + ) + ), + PromptMemoryEntry( + entry=MessagePiece( + role="assistant", + original_value="Hello 2", + attack_identifier=attack2.get_identifier(), + ) + ), + ] + + sqlite_instance._insert_entries(entries=entries) + + # Filter by exact attack hash + results = sqlite_instance.get_message_pieces( + attack_identifier_filter=AttackIdentifierFilter( + property_path=AttackIdentifierProperty.HASH, + value_to_match=attack1.get_identifier().hash, + partial_match=False, + ), + ) + assert len(results) == 1 + assert results[0].original_value == "Hello 1" + + # No match + results = sqlite_instance.get_message_pieces( + attack_identifier_filter=AttackIdentifierFilter( + property_path=AttackIdentifierProperty.HASH, + value_to_match="nonexistent_hash", + partial_match=False, + ), + ) + assert len(results) == 0 + + +def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryInterface): + target_id_1 = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target", + params={"endpoint": "https://api.openai.com", "model_name": "gpt-4"}, + ) + target_id_2 = ComponentIdentifier( + class_name="AzureChatTarget", + class_module="pyrit.prompt_target", + params={"endpoint": "https://azure.com", "model_name": "gpt-3.5"}, + ) + + entries = [ + PromptMemoryEntry( + entry=MessagePiece( + role="user", + 
original_value="Hello OpenAI", + prompt_target_identifier=target_id_1, + ) + ), + PromptMemoryEntry( + entry=MessagePiece( + role="user", + original_value="Hello Azure", + prompt_target_identifier=target_id_2, + ) + ), + ] + + sqlite_instance._insert_entries(entries=entries) + + # Filter by target hash + results = sqlite_instance.get_message_pieces( + prompt_target_identifier_filter=TargetIdentifierFilter( + property_path=TargetIdentifierProperty.HASH, + value_to_match=target_id_1.hash, + partial_match=False, + ), + ) + assert len(results) == 1 + assert results[0].original_value == "Hello OpenAI" + + # Filter by endpoint partial match + results = sqlite_instance.get_message_pieces( + prompt_target_identifier_filter=TargetIdentifierFilter( + property_path=TargetIdentifierProperty.ENDPOINT, + value_to_match="openai", + partial_match=True, + ), + ) + assert len(results) == 1 + assert results[0].original_value == "Hello OpenAI" + + # No match + results = sqlite_instance.get_message_pieces( + prompt_target_identifier_filter=TargetIdentifierFilter( + property_path=TargetIdentifierProperty.HASH, + value_to_match="nonexistent", + partial_match=False, + ), + ) + assert len(results) == 0 diff --git a/tests/unit/memory/memory_interface/test_interface_scenario_results.py b/tests/unit/memory/memory_interface/test_interface_scenario_results.py index e513e8b873..51b64a819b 100644 --- a/tests/unit/memory/memory_interface/test_interface_scenario_results.py +++ b/tests/unit/memory/memory_interface/test_interface_scenario_results.py @@ -9,6 +9,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface +from pyrit.memory.identifier_filters import TargetIdentifierFilter, TargetIdentifierProperty from pyrit.models import ( AttackOutcome, AttackResult, @@ -645,3 +646,116 @@ def test_combined_filters(sqlite_instance: MemoryInterface): assert len(results) == 1 assert results[0].scenario_identifier.pyrit_version == "0.5.0" assert "gpt-4" in 
results[0].objective_target_identifier.params["model_name"] + + +def test_get_scenario_results_by_target_identifier_filter_hash(sqlite_instance: MemoryInterface): + """Test filtering scenario results by TargetIdentifierFilter with hash.""" + target_id_1 = ComponentIdentifier( + class_name="OpenAI", + class_module="test", + params={"endpoint": "https://api.openai.com", "model_name": "gpt-4"}, + ) + target_id_2 = ComponentIdentifier( + class_name="Azure", + class_module="test", + params={"endpoint": "https://azure.com", "model_name": "gpt-3.5"}, + ) + + attack_result1 = create_attack_result("conv_1", "Objective 1") + attack_result2 = create_attack_result("conv_2", "Objective 2") + sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2]) + + scenario1 = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="Scenario OpenAI", scenario_version=1), + objective_target_identifier=target_id_1, + attack_results={"Attack1": [attack_result1]}, + objective_scorer_identifier=get_mock_scorer_identifier(), + ) + scenario2 = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="Scenario Azure", scenario_version=1), + objective_target_identifier=target_id_2, + attack_results={"Attack2": [attack_result2]}, + objective_scorer_identifier=get_mock_scorer_identifier(), + ) + sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario1, scenario2]) + + # Filter by target hash + results = sqlite_instance.get_scenario_results( + objective_target_identifier_filter=TargetIdentifierFilter( + property_path=TargetIdentifierProperty.HASH, + value_to_match=target_id_1.hash, + partial_match=False, + ), + ) + assert len(results) == 1 + assert results[0].scenario_identifier.name == "Scenario OpenAI" + + +def test_get_scenario_results_by_target_identifier_filter_endpoint(sqlite_instance: MemoryInterface): + """Test filtering scenario results by TargetIdentifierFilter with endpoint.""" + target_id_1 = ComponentIdentifier( + 
class_name="OpenAI", + class_module="test", + params={"endpoint": "https://api.openai.com", "model_name": "gpt-4"}, + ) + target_id_2 = ComponentIdentifier( + class_name="Azure", + class_module="test", + params={"endpoint": "https://azure.com", "model_name": "gpt-3.5"}, + ) + + attack_result1 = create_attack_result("conv_1", "Objective 1") + attack_result2 = create_attack_result("conv_2", "Objective 2") + sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2]) + + scenario1 = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="Scenario OpenAI", scenario_version=1), + objective_target_identifier=target_id_1, + attack_results={"Attack1": [attack_result1]}, + objective_scorer_identifier=get_mock_scorer_identifier(), + ) + scenario2 = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="Scenario Azure", scenario_version=1), + objective_target_identifier=target_id_2, + attack_results={"Attack2": [attack_result2]}, + objective_scorer_identifier=get_mock_scorer_identifier(), + ) + sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario1, scenario2]) + + # Filter by endpoint partial match + results = sqlite_instance.get_scenario_results( + objective_target_identifier_filter=TargetIdentifierFilter( + property_path=TargetIdentifierProperty.ENDPOINT, + value_to_match="openai", + partial_match=True, + ), + ) + assert len(results) == 1 + assert results[0].scenario_identifier.name == "Scenario OpenAI" + + +def test_get_scenario_results_by_target_identifier_filter_no_match(sqlite_instance: MemoryInterface): + """Test that TargetIdentifierFilter returns empty when nothing matches.""" + attack_result1 = create_attack_result("conv_1", "Objective 1") + sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1]) + + scenario1 = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="Test Scenario", scenario_version=1), + objective_target_identifier=ComponentIdentifier( + 
class_name="OpenAI", + class_module="test", + params={"endpoint": "https://api.openai.com"}, + ), + attack_results={"Attack1": [attack_result1]}, + objective_scorer_identifier=get_mock_scorer_identifier(), + ) + sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario1]) + + results = sqlite_instance.get_scenario_results( + objective_target_identifier_filter=TargetIdentifierFilter( + property_path=TargetIdentifierProperty.HASH, + value_to_match="nonexistent_hash", + partial_match=False, + ), + ) + assert len(results) == 0 diff --git a/tests/unit/memory/memory_interface/test_interface_scores.py b/tests/unit/memory/memory_interface/test_interface_scores.py index 6087af1418..e9945bfc2e 100644 --- a/tests/unit/memory/memory_interface/test_interface_scores.py +++ b/tests/unit/memory/memory_interface/test_interface_scores.py @@ -13,6 +13,7 @@ from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface, PromptMemoryEntry +from pyrit.memory.identifier_filters import ScorerIdentifierFilter, ScorerIdentifierProperty from pyrit.models import ( MessagePiece, Score, @@ -227,3 +228,76 @@ async def test_get_seeds_no_filters(sqlite_instance: MemoryInterface): assert len(result) == 2 assert result[0].value == "prompt1" assert result[1].value == "prompt2" + + +def test_get_scores_by_scorer_identifier_filter( + sqlite_instance: MemoryInterface, sample_conversation_entries: Sequence[PromptMemoryEntry], +): + prompt_id = sample_conversation_entries[0].id + sqlite_instance._insert_entries(entries=sample_conversation_entries) + + score_a = Score( + score_value="0.9", + score_value_description="High", + score_type="float_scale", + score_category=["cat_a"], + score_rationale="Rationale A", + score_metadata={}, + scorer_class_identifier=_test_scorer_id("ScorerAlpha"), + message_piece_id=prompt_id, + ) + score_b = Score( + score_value="0.1", + 
score_value_description="Low", + score_type="float_scale", + score_category=["cat_b"], + score_rationale="Rationale B", + score_metadata={}, + scorer_class_identifier=_test_scorer_id("ScorerBeta"), + message_piece_id=prompt_id, + ) + + sqlite_instance.add_scores_to_memory(scores=[score_a, score_b]) + + # Filter by exact class_name match + results = sqlite_instance.get_scores( + scorer_identifier_filter=ScorerIdentifierFilter( + property_path=ScorerIdentifierProperty.CLASS_NAME, + value_to_match="ScorerAlpha", + partial_match=False, + ), + ) + assert len(results) == 1 + assert results[0].score_value == "0.9" + + # Filter by partial class_name match + results = sqlite_instance.get_scores( + scorer_identifier_filter=ScorerIdentifierFilter( + property_path=ScorerIdentifierProperty.CLASS_NAME, + value_to_match="Scorer", + partial_match=True, + ), + ) + assert len(results) == 2 + + # Filter by hash + scorer_hash = score_a.scorer_class_identifier.hash + results = sqlite_instance.get_scores( + scorer_identifier_filter=ScorerIdentifierFilter( + property_path=ScorerIdentifierProperty.HASH, + value_to_match=scorer_hash, + partial_match=False, + ), + ) + assert len(results) == 1 + assert results[0].score_value == "0.9" + + # No match + results = sqlite_instance.get_scores( + scorer_identifier_filter=ScorerIdentifierFilter( + property_path=ScorerIdentifierProperty.CLASS_NAME, + value_to_match="NonExistent", + partial_match=False, + ), + ) + assert len(results) == 0 From 01aaa159e559247699bee95217923722d6955d46 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 1 Apr 2026 13:56:45 -0700 Subject: [PATCH 2/9] forgot formatting --- pyrit/memory/__init__.py | 11 ++++++++++- pyrit/memory/azure_sql_memory.py | 12 +++++------- pyrit/memory/memory_interface.py | 16 ++++++++-------- pyrit/memory/sqlite_memory.py | 8 ++++---- .../memory_interface/test_interface_scores.py | 3 ++- 5 files changed, 29 insertions(+), 21 deletions(-) diff --git a/pyrit/memory/__init__.py 
b/pyrit/memory/__init__.py index 102a1f8607..a22469de00 100644 --- a/pyrit/memory/__init__.py +++ b/pyrit/memory/__init__.py @@ -7,9 +7,18 @@ This package defines the core `MemoryInterface` and concrete implementations for different storage backends. """ -from pyrit.memory.identifier_filters import AttackIdentifierFilter, AttackIdentifierProperty, ConverterIdentifierFilter, ConverterIdentifierProperty, ScorerIdentifierFilter, ScorerIdentifierProperty, TargetIdentifierFilter, TargetIdentifierProperty from pyrit.memory.azure_sql_memory import AzureSQLMemory from pyrit.memory.central_memory import CentralMemory +from pyrit.memory.identifier_filters import ( + AttackIdentifierFilter, + AttackIdentifierProperty, + ConverterIdentifierFilter, + ConverterIdentifierProperty, + ScorerIdentifierFilter, + ScorerIdentifierProperty, + TargetIdentifierFilter, + TargetIdentifierProperty, +) from pyrit.memory.memory_embedding import MemoryEmbedding from pyrit.memory.memory_exporter import MemoryExporter from pyrit.memory.memory_interface import MemoryInterface diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 48ae2c5df2..fc7a951f1e 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -319,10 +319,10 @@ def _get_condition_json_property_match( return text( f"""ISJSON("{table_name}".{column_name}) = 1 AND LOWER(JSON_VALUE("{table_name}".{column_name}, :property_path)) {"LIKE" if partial_match else "="} :match_property_value""" # noqa: E501 - ).bindparams( - property_path=property_path, - match_property_value=f"%{value_to_match.lower()}%" if partial_match else value_to_match, - ) + ).bindparams( + property_path=property_path, + match_property_value=f"%{value_to_match.lower()}%" if partial_match else value_to_match, + ) # The above return statement already handles both partial and exact matches # The following code is now unreachable and can be removed @@ -360,9 +360,7 @@ def _get_condition_json_array_match( 
bindparams_dict[param_name] = match_value.lower() if case_insensitive else match_value combined = " AND ".join(conditions) - return text(f"""ISJSON("{table_name}".{column_name}) = 1 AND {combined}""").bindparams( - **bindparams_dict - ) + return text(f"""ISJSON("{table_name}".{column_name}) = 1 AND {combined}""").bindparams(**bindparams_dict) def _get_unique_json_property_values( self, diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 5bc1f4ad3e..0fcdfc6f3c 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -374,7 +374,7 @@ def get_unique_attack_class_names(self) -> list[str]: """ return self._get_unique_json_array_values( json_column=AttackResultEntry.atomic_attack_identifier, - path_to_array=AttackIdentifierProperty.ATTACK_CLASS_NAME + path_to_array=AttackIdentifierProperty.ATTACK_CLASS_NAME, ) def get_unique_converter_class_names(self) -> list[str]: @@ -638,7 +638,7 @@ def get_message_pieces( not_data_type (Optional[str], optional): The data type to exclude. Defaults to None. converted_value_sha256 (Optional[Sequence[str]], optional): A list of SHA256 hashes of converted values. Defaults to None. - attack_identifier_filter (Optional[AttackIdentifierFilter], optional): + attack_identifier_filter (Optional[AttackIdentifierFilter], optional): An AttackIdentifierFilter object that allows filtering by various attack identifier JSON properties. Defaults to None. 
prompt_target_identifier_filter (Optional[TargetIdentifierFilter], optional): @@ -658,7 +658,7 @@ def get_message_pieces( self._get_condition_json_property_match( json_column=PromptMemoryEntry.attack_identifier, property_path=AttackIdentifierProperty.HASH, - value_to_match=str(attack_id) + value_to_match=str(attack_id), ) ) if role: @@ -1770,12 +1770,12 @@ def get_scenario_results( # Use database-specific JSON query method conditions.append( self._get_condition_json_property_match( - json_column=ScenarioResultEntry.objective_target_identifier, - property_path=TargetIdentifierProperty.ENDPOINT, - value_to_match=objective_target_endpoint, - partial_match=True, + json_column=ScenarioResultEntry.objective_target_identifier, + property_path=TargetIdentifierProperty.ENDPOINT, + value_to_match=objective_target_endpoint, + partial_match=True, + ) ) - ) if objective_target_model_name: # Use database-specific JSON query method diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index a41dbffc90..3e94e0e2ea 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -234,9 +234,9 @@ def _get_condition_json_array_match( } conditions.append( text( - f'''EXISTS(SELECT 1 FROM json_each( + f"""EXISTS(SELECT 1 FROM json_each( json_extract("{table_name}".{column_name}, :property_path)) - WHERE {value_expression} = :{param_name})''' + WHERE {value_expression} = :{param_name})""" ).bindparams(**bind_params) ) return and_(*conditions) @@ -257,10 +257,10 @@ def _get_unique_json_array_values( column_name = json_column.key rows = session.execute( text( - f'''SELECT DISTINCT json_extract(j.value, :sub_path) AS value + f"""SELECT DISTINCT json_extract(j.value, :sub_path) AS value FROM "{table_name}", json_each(json_extract("{table_name}".{column_name}, :path_to_array)) AS j - WHERE json_extract(j.value, :sub_path) IS NOT NULL''' + WHERE json_extract(j.value, :sub_path) IS NOT NULL""" ).bindparams( path_to_array=path_to_array, sub_path=sub_path, 
diff --git a/tests/unit/memory/memory_interface/test_interface_scores.py b/tests/unit/memory/memory_interface/test_interface_scores.py index e9945bfc2e..bb9478c3b6 100644 --- a/tests/unit/memory/memory_interface/test_interface_scores.py +++ b/tests/unit/memory/memory_interface/test_interface_scores.py @@ -231,7 +231,8 @@ async def test_get_seeds_no_filters(sqlite_instance: MemoryInterface): def test_get_scores_by_scorer_identifier_filter( - sqlite_instance: MemoryInterface, sample_conversation_entries: Sequence[PromptMemoryEntry], + sqlite_instance: MemoryInterface, + sample_conversation_entries: Sequence[PromptMemoryEntry], ): prompt_id = sample_conversation_entries[0].id sqlite_instance._insert_entries(entries=sample_conversation_entries) From e77b43c0b604e162242791578df6611f44376a5b Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 1 Apr 2026 14:05:36 -0700 Subject: [PATCH 3/9] return str --- pyrit/memory/identifier_filters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/memory/identifier_filters.py b/pyrit/memory/identifier_filters.py index 8792f03241..10aba39aa5 100644 --- a/pyrit/memory/identifier_filters.py +++ b/pyrit/memory/identifier_filters.py @@ -12,7 +12,7 @@ class _StrEnum(str, Enum): """Base class that mimics StrEnum behavior for Python < 3.11.""" def __str__(self) -> str: - return self.value + return str(self.value) T = TypeVar("T", bound=_StrEnum) From a06b5060ca25add903bef9055f5435dfe9a05779 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 1 Apr 2026 14:08:55 -0700 Subject: [PATCH 4/9] fix method name --- pyrit/memory/azure_sql_memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index fc7a951f1e..cf9c5f6d49 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -362,7 +362,7 @@ def _get_condition_json_array_match( combined = " AND ".join(conditions) return 
text(f"""ISJSON("{table_name}".{column_name}) = 1 AND {combined}""").bindparams(**bindparams_dict) - def _get_unique_json_property_values( + def _get_unique_json_array_values( self, *, json_column: Any, From 9d3cb5f378ea1f30d163a798aed57f84875c4964 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 2 Apr 2026 09:39:51 -0700 Subject: [PATCH 5/9] add back public methods --- pyrit/memory/azure_sql_memory.py | 21 +++++++++++++++++++++ pyrit/memory/sqlite_memory.py | 21 +++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index cf9c5f6d49..8941078e4f 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -571,6 +571,27 @@ def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece] """ self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in message_pieces]) + def get_unique_attack_class_names(self) -> list[str]: + """ + Azure SQL implementation: extract unique class_name values from + the atomic_attack_identifier JSON column. + + Returns: + Sorted list of unique attack class name strings. + """ + return super().get_unique_attack_class_names() + + def get_unique_converter_class_names(self) -> list[str]: + """ + Azure SQL implementation: extract unique converter class_name values + from the children.attack.children.request_converters array + in the atomic_attack_identifier JSON column. + + Returns: + Sorted list of unique converter class name strings. + """ + return super().get_unique_converter_class_names() + def dispose_engine(self) -> None: """ Dispose the engine and clean up resources. 
diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 3e94e0e2ea..f76f300a3d 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -268,6 +268,27 @@ def _get_unique_json_array_values( ).fetchall() return sorted(row[0] for row in rows) + def get_unique_attack_class_names(self) -> list[str]: + """ + SQLite implementation: extract unique class_name values from + the atomic_attack_identifier JSON column. + + Returns: + Sorted list of unique attack class name strings. + """ + return super().get_unique_attack_class_names() + + def get_unique_converter_class_names(self) -> list[str]: + """ + SQLite implementation: extract unique converter class_name values + from the children.attack.children.request_converters array in the + atomic_attack_identifier JSON column. + + Returns: + Sorted list of unique converter class name strings. + """ + return super().get_unique_converter_class_names() + def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: """ Insert a list of message pieces into the memory storage. 
From 5389a9f4c85cabbf987873077ae97da6c2c1b97f Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 2 Apr 2026 11:42:13 -0700 Subject: [PATCH 6/9] custom subpath for array match and make all matches case insensitive --- pyrit/memory/azure_sql_memory.py | 14 +++++--------- pyrit/memory/memory_interface.py | 6 +++--- pyrit/memory/sqlite_memory.py | 14 ++++++-------- 3 files changed, 14 insertions(+), 20 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 8941078e4f..5ff9710ceb 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -321,18 +321,16 @@ def _get_condition_json_property_match( AND LOWER(JSON_VALUE("{table_name}".{column_name}, :property_path)) {"LIKE" if partial_match else "="} :match_property_value""" # noqa: E501 ).bindparams( property_path=property_path, - match_property_value=f"%{value_to_match.lower()}%" if partial_match else value_to_match, + match_property_value=f"%{value_to_match.lower()}%", ) - # The above return statement already handles both partial and exact matches - # The following code is now unreachable and can be removed def _get_condition_json_array_match( self, *, json_column: Any, property_path: str, - array_to_match: Sequence[str], - case_insensitive: bool = False, + sub_path: str | None = None, + array_to_match: Sequence[str] ) -> Any: table_name = json_column.class_.__tablename__ column_name = json_column.key @@ -343,9 +341,7 @@ def _get_condition_json_array_match( OR JSON_QUERY("{table_name}".{column_name}, :property_path) = '[]')""" ).bindparams(property_path=property_path) - value_expression = "JSON_VALUE(value, '$.class_name')" - if case_insensitive: - value_expression = f"LOWER({value_expression})" + value_expression = f"LOWER(JSON_VALUE(value, '{sub_path}'))" if sub_path else "LOWER(value)" conditions = [] bindparams_dict: dict[str, str] = {"property_path": property_path} @@ -357,7 +353,7 @@ def _get_condition_json_array_match( :property_path)) 
WHERE {value_expression} = :{param_name})""" ) - bindparams_dict[param_name] = match_value.lower() if case_insensitive else match_value + bindparams_dict[param_name] = match_value.lower() combined = " AND ".join(conditions) return text(f"""ISJSON("{table_name}".{column_name}) = 1 AND {combined}""").bindparams(**bindparams_dict) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 0fcdfc6f3c..74f99f0217 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -149,8 +149,8 @@ def _get_condition_json_array_match( *, json_column: InstrumentedAttribute[Any], property_path: str, + sub_path: Optional[str] = None, array_to_match: Sequence[str], - case_insensitive: bool = False, ) -> Any: """ Return a database-specific condition for matching an array at a given path within a JSON object. @@ -158,10 +158,10 @@ def _get_condition_json_array_match( Args: json_column (InstrumentedAttribute[Any]): The JSON-backed SQLAlchemy field to query. property_path (str): The JSON path for the target array. + sub_path (Optional[str]): An optional JSON path applied to each array item before matching. array_to_match (Sequence[str]): The array that must match the extracted JSON array values. For a match, ALL values in this array must be present in the JSON array. If `array_to_match` is empty, the condition must match only if the target is also an empty array or None. - case_insensitive (bool): Whether string comparison should ignore casing. Returns: Any: A database-specific SQLAlchemy condition. 
@@ -1500,8 +1500,8 @@ def get_attack_results( self._get_condition_json_array_match( json_column=AttackResultEntry.atomic_attack_identifier, property_path=AttackIdentifierProperty.REQUEST_CONVERTERS, + sub_path=ConverterIdentifierProperty.CLASS_NAME, array_to_match=converter_classes, - case_insensitive=True, ) ) diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index f76f300a3d..fa9487055e 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -198,18 +198,18 @@ def _get_condition_json_property_match( value_to_match: str, partial_match: bool = False, ) -> Any: - extracted_value = func.json_extract(json_column, property_path) + extracted_value = func.lower(func.json_extract(json_column, property_path)) if partial_match: - return func.lower(extracted_value).like(f"%{value_to_match.lower()}%") - return extracted_value == value_to_match + return extracted_value.like(f"%{value_to_match.lower()}%") + return extracted_value == value_to_match.lower() def _get_condition_json_array_match( self, *, json_column: Any, property_path: str, + sub_path: str | None = None, array_to_match: Sequence[str], - case_insensitive: bool = False, ) -> Any: array_expr = func.json_extract(json_column, property_path) if len(array_to_match) == 0: @@ -221,16 +221,14 @@ def _get_condition_json_array_match( table_name = json_column.class_.__tablename__ column_name = json_column.key - value_expression = "json_extract(value, '$.class_name')" - if case_insensitive: - value_expression = f"LOWER({value_expression})" + value_expression = f"LOWER(json_extract(value, '{sub_path}'))" if sub_path else "LOWER(value)" conditions = [] for index, match_value in enumerate(array_to_match): param_name = f"match_value_{index}" bind_params: dict[str, str] = { "property_path": property_path, - param_name: match_value.lower() if case_insensitive else match_value, + param_name: match_value.lower(), } conditions.append( text( From 3fa071367a9a25486ec842fcc419ab3c4fa58027 
Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 2 Apr 2026 11:47:45 -0700 Subject: [PATCH 7/9] format --- pyrit/memory/azure_sql_memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 5ff9710ceb..916f64508d 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -330,7 +330,7 @@ def _get_condition_json_array_match( json_column: Any, property_path: str, sub_path: str | None = None, - array_to_match: Sequence[str] + array_to_match: Sequence[str], ) -> Any: table_name = json_column.class_.__tablename__ column_name = json_column.key From 24f61d1ecb73b3c171c29072a497ef9c67981ab6 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 2 Apr 2026 14:26:39 -0700 Subject: [PATCH 8/9] allow free-form paths in identifier filters --- pyrit/memory/__init__.py | 20 +---- pyrit/memory/identifier_filters.py | 82 +------------------ pyrit/memory/memory_interface.py | 55 ++++++------- .../test_interface_attack_results.py | 23 ++---- .../test_interface_prompts.py | 27 +++--- .../test_interface_scenario_results.py | 18 ++-- .../memory_interface/test_interface_scores.py | 18 ++-- 7 files changed, 64 insertions(+), 179 deletions(-) diff --git a/pyrit/memory/__init__.py b/pyrit/memory/__init__.py index a22469de00..6098122d7d 100644 --- a/pyrit/memory/__init__.py +++ b/pyrit/memory/__init__.py @@ -9,16 +9,7 @@ from pyrit.memory.azure_sql_memory import AzureSQLMemory from pyrit.memory.central_memory import CentralMemory -from pyrit.memory.identifier_filters import ( - AttackIdentifierFilter, - AttackIdentifierProperty, - ConverterIdentifierFilter, - ConverterIdentifierProperty, - ScorerIdentifierFilter, - ScorerIdentifierProperty, - TargetIdentifierFilter, - TargetIdentifierProperty, -) +from pyrit.memory.identifier_filters import IdentifierFilter from pyrit.memory.memory_embedding import MemoryEmbedding from pyrit.memory.memory_exporter import MemoryExporter 
from pyrit.memory.memory_interface import MemoryInterface @@ -27,10 +18,6 @@ __all__ = [ "AttackResultEntry", - "AttackIdentifierFilter", - "AttackIdentifierProperty", - "ConverterIdentifierFilter", - "ConverterIdentifierProperty", "AzureSQLMemory", "CentralMemory", "SQLiteMemory", @@ -39,9 +26,6 @@ "MemoryEmbedding", "MemoryExporter", "PromptMemoryEntry", - "ScorerIdentifierFilter", - "ScorerIdentifierProperty", "SeedEntry", - "TargetIdentifierFilter", - "TargetIdentifierProperty", + "IdentifierFilter", ] diff --git a/pyrit/memory/identifier_filters.py b/pyrit/memory/identifier_filters.py index 10aba39aa5..74c62c877a 100644 --- a/pyrit/memory/identifier_filters.py +++ b/pyrit/memory/identifier_filters.py @@ -1,95 +1,17 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from abc import ABC from dataclasses import dataclass -from enum import Enum -from typing import Generic, TypeVar - - -# TODO: if/when we move to python 3.11+, we can replace this with StrEnum -class _StrEnum(str, Enum): - """Base class that mimics StrEnum behavior for Python < 3.11.""" - - def __str__(self) -> str: - return str(self.value) - - -T = TypeVar("T", bound=_StrEnum) - - -class IdentifierProperty(_StrEnum): - """Allowed JSON paths for identifier filtering.""" - - HASH = "$.hash" - CLASS_NAME = "$.class_name" @dataclass(frozen=True) -class IdentifierFilter(ABC, Generic[T]): +class IdentifierFilter: """Immutable filter definition for matching JSON-backed identifier properties.""" - property_path: T | str + property_path: str value_to_match: str partial_match: bool = False def __post_init__(self) -> None: """Normalize and validate the configured property path.""" object.__setattr__(self, "property_path", str(self.property_path)) - - -class AttackIdentifierProperty(_StrEnum): - """Allowed JSON paths for attack identifier filtering.""" - - HASH = "$.hash" - ATTACK_CLASS_NAME = "$.children.attack.class_name" - REQUEST_CONVERTERS = 
"$.children.attack.children.request_converters" - - -class TargetIdentifierProperty(_StrEnum): - """Allowed JSON paths for target identifier filtering.""" - - HASH = "$.hash" - ENDPOINT = "$.endpoint" - MODEL_NAME = "$.model_name" - - -class ConverterIdentifierProperty(_StrEnum): - """Allowed JSON paths for converter identifier filtering.""" - - HASH = "$.hash" - CLASS_NAME = "$.class_name" - - -class ScorerIdentifierProperty(_StrEnum): - """Allowed JSON paths for scorer identifier filtering.""" - - HASH = "$.hash" - CLASS_NAME = "$.class_name" - - -@dataclass(frozen=True) -class AttackIdentifierFilter(IdentifierFilter[AttackIdentifierProperty]): - """ - Immutable filter definition for matching JSON-backed attack identifier properties. - - Args: - property_path: The JSON path of the property to filter on. - value_to_match: The value to match against the property. - partial_match: Whether to allow partial matches (default: False). - """ - - -@dataclass(frozen=True) -class TargetIdentifierFilter(IdentifierFilter[TargetIdentifierProperty]): - """Immutable filter definition for matching JSON-backed target identifier properties.""" - - -@dataclass(frozen=True) -class ConverterIdentifierFilter(IdentifierFilter[ConverterIdentifierProperty]): - """Immutable filter definition for matching JSON-backed converter identifier properties.""" - - -@dataclass(frozen=True) -class ScorerIdentifierFilter(IdentifierFilter[ScorerIdentifierProperty]): - """Immutable filter definition for matching JSON-backed scorer identifier properties.""" diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 74f99f0217..1ef99789d8 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -19,14 +19,7 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute from pyrit.common.path import DB_DATA_PATH -from pyrit.memory.identifier_filters import ( - AttackIdentifierFilter, - AttackIdentifierProperty, - ConverterIdentifierProperty, - 
ScorerIdentifierFilter, - TargetIdentifierFilter, - TargetIdentifierProperty, -) +from pyrit.memory.identifier_filters import IdentifierFilter from pyrit.memory.memory_embedding import ( MemoryEmbedding, default_memory_embedding_factory, @@ -374,7 +367,7 @@ def get_unique_attack_class_names(self) -> list[str]: """ return self._get_unique_json_array_values( json_column=AttackResultEntry.atomic_attack_identifier, - path_to_array=AttackIdentifierProperty.ATTACK_CLASS_NAME, + path_to_array="$.children.attack.class_name", ) def get_unique_converter_class_names(self) -> list[str]: @@ -389,8 +382,8 @@ def get_unique_converter_class_names(self) -> list[str]: """ return self._get_unique_json_array_values( json_column=AttackResultEntry.atomic_attack_identifier, - path_to_array=AttackIdentifierProperty.REQUEST_CONVERTERS, - sub_path=ConverterIdentifierProperty.CLASS_NAME, + path_to_array="$.children.attack.children.request_converters", + sub_path="$.class_name", ) @abc.abstractmethod @@ -447,7 +440,7 @@ def get_scores( score_category: Optional[str] = None, sent_after: Optional[datetime] = None, sent_before: Optional[datetime] = None, - scorer_identifier_filter: Optional[ScorerIdentifierFilter] = None, + scorer_identifier_filter: Optional[IdentifierFilter] = None, ) -> Sequence[Score]: """ Retrieve a list of Score objects based on the specified filters. @@ -458,7 +451,7 @@ def get_scores( score_category (Optional[str]): The category of the score to filter by. sent_after (Optional[datetime]): Filter for scores sent after this datetime. sent_before (Optional[datetime]): Filter for scores sent before this datetime. - scorer_identifier_filter (Optional[ScorerIdentifierFilter]): A ScorerIdentifierFilter object that + scorer_identifier_filter (Optional[IdentifierFilter]): An IdentifierFilter object that allows filtering by various scorer identifier JSON properties. Defaults to None. 
Returns: @@ -615,8 +608,8 @@ def get_message_pieces( data_type: Optional[str] = None, not_data_type: Optional[str] = None, converted_value_sha256: Optional[Sequence[str]] = None, - attack_identifier_filter: Optional[AttackIdentifierFilter] = None, - prompt_target_identifier_filter: Optional[TargetIdentifierFilter] = None, + attack_identifier_filter: Optional[IdentifierFilter] = None, + prompt_target_identifier_filter: Optional[IdentifierFilter] = None, ) -> Sequence[MessagePiece]: """ Retrieve a list of MessagePiece objects based on the specified filters. @@ -638,11 +631,11 @@ def get_message_pieces( not_data_type (Optional[str], optional): The data type to exclude. Defaults to None. converted_value_sha256 (Optional[Sequence[str]], optional): A list of SHA256 hashes of converted values. Defaults to None. - attack_identifier_filter (Optional[AttackIdentifierFilter], optional): - An AttackIdentifierFilter object that + attack_identifier_filter (Optional[IdentifierFilter], optional): + An IdentifierFilter object that allows filtering by various attack identifier JSON properties. Defaults to None. - prompt_target_identifier_filter (Optional[TargetIdentifierFilter], optional): - A TargetIdentifierFilter object that + prompt_target_identifier_filter (Optional[IdentifierFilter], optional): + An IdentifierFilter object that allows filtering by various target identifier JSON properties. Defaults to None. 
Returns: @@ -657,7 +650,7 @@ def get_message_pieces( conditions.append( self._get_condition_json_property_match( json_column=PromptMemoryEntry.attack_identifier, - property_path=AttackIdentifierProperty.HASH, + property_path="$.hash", value_to_match=str(attack_id), ) ) @@ -1431,7 +1424,7 @@ def get_attack_results( converter_classes: Optional[Sequence[str]] = None, targeted_harm_categories: Optional[Sequence[str]] = None, labels: Optional[dict[str, str]] = None, - attack_identifier_filter: Optional[AttackIdentifierFilter] = None, + attack_identifier_filter: Optional[IdentifierFilter] = None, ) -> Sequence[AttackResult]: """ Retrieve a list of AttackResult objects based on the specified filters. @@ -1459,8 +1452,8 @@ def get_attack_results( labels (Optional[dict[str, str]], optional): A dictionary of memory labels to filter results by. These labels are associated with the prompts themselves, used for custom tagging and tracking. Defaults to None. - attack_identifier_filter (Optional[AttackIdentifierFilter], optional): - An AttackIdentifierFilter object that allows filtering by various attack identifier + attack_identifier_filter (Optional[IdentifierFilter], optional): + An IdentifierFilter object that allows filtering by various attack identifier JSON properties. Defaults to None. 
Returns: @@ -1488,7 +1481,7 @@ def get_attack_results( conditions.append( self._get_condition_json_property_match( json_column=AttackResultEntry.atomic_attack_identifier, - property_path=AttackIdentifierProperty.ATTACK_CLASS_NAME, + property_path="$.children.attack.class_name", value_to_match=attack_class, ) ) @@ -1499,8 +1492,8 @@ def get_attack_results( conditions.append( self._get_condition_json_array_match( json_column=AttackResultEntry.atomic_attack_identifier, - property_path=AttackIdentifierProperty.REQUEST_CONVERTERS, - sub_path=ConverterIdentifierProperty.CLASS_NAME, + property_path="$.children.attack.children.request_converters", + sub_path="$.class_name", array_to_match=converter_classes, ) ) @@ -1705,7 +1698,7 @@ def get_scenario_results( labels: Optional[dict[str, str]] = None, objective_target_endpoint: Optional[str] = None, objective_target_model_name: Optional[str] = None, - objective_target_identifier_filter: Optional[TargetIdentifierFilter] = None, + objective_target_identifier_filter: Optional[IdentifierFilter] = None, ) -> Sequence[ScenarioResult]: """ Retrieve a list of ScenarioResult objects based on the specified filters. @@ -1729,8 +1722,8 @@ def get_scenario_results( objective_target_model_name (Optional[str], optional): Filter for scenarios where the objective_target_identifier has a model_name attribute containing this value (case-insensitive). Defaults to None. - objective_target_identifier_filter (Optional[TargetIdentifierFilter], optional): - A TargetIdentifierFilter object that allows filtering by various target identifier JSON properties. + objective_target_identifier_filter (Optional[IdentifierFilter], optional): + An IdentifierFilter object that allows filtering by various target identifier JSON properties. Returns: Sequence[ScenarioResult]: A list of ScenarioResult objects that match the specified filters. 
@@ -1771,7 +1764,7 @@ def get_scenario_results( conditions.append( self._get_condition_json_property_match( json_column=ScenarioResultEntry.objective_target_identifier, - property_path=TargetIdentifierProperty.ENDPOINT, + property_path="$.endpoint", value_to_match=objective_target_endpoint, partial_match=True, ) @@ -1782,7 +1775,7 @@ def get_scenario_results( conditions.append( self._get_condition_json_property_match( json_column=ScenarioResultEntry.objective_target_identifier, - property_path=TargetIdentifierProperty.MODEL_NAME, + property_path="$.model_name", value_to_match=objective_target_model_name, partial_match=True, ) diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index de238952f4..84cda0b409 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -9,7 +9,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier from pyrit.memory import MemoryInterface -from pyrit.memory.identifier_filters import AttackIdentifierFilter, AttackIdentifierProperty +from pyrit.memory.identifier_filters import IdentifierFilter from pyrit.memory.memory_models import AttackResultEntry from pyrit.models import ( AttackOutcome, @@ -1176,15 +1176,6 @@ def test_get_attack_results_by_attack_class_no_match(sqlite_instance: MemoryInte assert len(results) == 0 -def test_get_attack_results_by_attack_class_case_sensitive(sqlite_instance: MemoryInterface): - """Test that attack_class filter is case-sensitive (exact match).""" - ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") - sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) - - results = sqlite_instance.get_attack_results(attack_class="crescendoattack") - assert len(results) == 0 - - def 
test_get_attack_results_by_attack_class_no_identifier(sqlite_instance: MemoryInterface): """Test that attacks with no attack_identifier (empty JSON) are excluded by attack_class filter.""" ar1 = create_attack_result("conv_1", 1) # No attack_identifier → stored as {} @@ -1363,8 +1354,8 @@ def test_get_attack_results_by_attack_identifier_filter_hash(sqlite_instance: Me # Filter by hash of ar1's attack identifier results = sqlite_instance.get_attack_results( - attack_identifier_filter=AttackIdentifierFilter( - property_path=AttackIdentifierProperty.HASH, + attack_identifier_filter=IdentifierFilter( + property_path="$.hash", value_to_match=ar1.atomic_attack_identifier.hash, partial_match=False, ), @@ -1382,8 +1373,8 @@ def test_get_attack_results_by_attack_identifier_filter_class_name(sqlite_instan # Filter by partial attack class name results = sqlite_instance.get_attack_results( - attack_identifier_filter=AttackIdentifierFilter( - property_path=AttackIdentifierProperty.ATTACK_CLASS_NAME, + attack_identifier_filter=IdentifierFilter( + property_path="$.children.attack.class_name", value_to_match="Crescendo", partial_match=True, ), @@ -1398,8 +1389,8 @@ def test_get_attack_results_by_attack_identifier_filter_no_match(sqlite_instance sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) results = sqlite_instance.get_attack_results( - attack_identifier_filter=AttackIdentifierFilter( - property_path=AttackIdentifierProperty.HASH, + attack_identifier_filter=IdentifierFilter( + property_path="$.hash", value_to_match="nonexistent_hash", partial_match=False, ), diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index 457169b911..eec4d3d88a 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -14,12 +14,7 @@ from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack 
from pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface, PromptMemoryEntry -from pyrit.memory.identifier_filters import ( - AttackIdentifierFilter, - AttackIdentifierProperty, - TargetIdentifierFilter, - TargetIdentifierProperty, -) +from pyrit.memory.identifier_filters import IdentifierFilter from pyrit.models import ( Message, MessagePiece, @@ -1281,8 +1276,8 @@ def test_get_message_pieces_by_attack_identifier_filter(sqlite_instance: MemoryI # Filter by exact attack hash results = sqlite_instance.get_message_pieces( - attack_identifier_filter=AttackIdentifierFilter( - property_path=AttackIdentifierProperty.HASH, + attack_identifier_filter=IdentifierFilter( + property_path="$.hash", value_to_match=attack1.get_identifier().hash, partial_match=False, ), @@ -1292,8 +1287,8 @@ def test_get_message_pieces_by_attack_identifier_filter(sqlite_instance: MemoryI # No match results = sqlite_instance.get_message_pieces( - attack_identifier_filter=AttackIdentifierFilter( - property_path=AttackIdentifierProperty.HASH, + attack_identifier_filter=IdentifierFilter( + property_path="$.hash", value_to_match="nonexistent_hash", partial_match=False, ), @@ -1334,8 +1329,8 @@ def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryI # Filter by target hash results = sqlite_instance.get_message_pieces( - prompt_target_identifier_filter=TargetIdentifierFilter( - property_path=TargetIdentifierProperty.HASH, + prompt_target_identifier_filter=IdentifierFilter( + property_path="$.hash", value_to_match=target_id_1.hash, partial_match=False, ), @@ -1345,8 +1340,8 @@ def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryI # Filter by endpoint partial match results = sqlite_instance.get_message_pieces( - prompt_target_identifier_filter=TargetIdentifierFilter( - property_path=TargetIdentifierProperty.ENDPOINT, + prompt_target_identifier_filter=IdentifierFilter( + property_path="$.endpoint", value_to_match="openai", 
partial_match=True, ), @@ -1356,8 +1351,8 @@ def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryI # No match results = sqlite_instance.get_message_pieces( - prompt_target_identifier_filter=TargetIdentifierFilter( - property_path=TargetIdentifierProperty.HASH, + prompt_target_identifier_filter=IdentifierFilter( + property_path="$.hash", value_to_match="nonexistent", partial_match=False, ), diff --git a/tests/unit/memory/memory_interface/test_interface_scenario_results.py b/tests/unit/memory/memory_interface/test_interface_scenario_results.py index 51b64a819b..ee2933b70a 100644 --- a/tests/unit/memory/memory_interface/test_interface_scenario_results.py +++ b/tests/unit/memory/memory_interface/test_interface_scenario_results.py @@ -9,7 +9,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface -from pyrit.memory.identifier_filters import TargetIdentifierFilter, TargetIdentifierProperty +from pyrit.memory.identifier_filters import IdentifierFilter from pyrit.models import ( AttackOutcome, AttackResult, @@ -649,7 +649,7 @@ def test_combined_filters(sqlite_instance: MemoryInterface): def test_get_scenario_results_by_target_identifier_filter_hash(sqlite_instance: MemoryInterface): - """Test filtering scenario results by TargetIdentifierFilter with hash.""" + """Test filtering scenario results by identifier filter.""" target_id_1 = ComponentIdentifier( class_name="OpenAI", class_module="test", @@ -681,8 +681,8 @@ def test_get_scenario_results_by_target_identifier_filter_hash(sqlite_instance: # Filter by target hash results = sqlite_instance.get_scenario_results( - objective_target_identifier_filter=TargetIdentifierFilter( - property_path=TargetIdentifierProperty.HASH, + objective_target_identifier_filter=IdentifierFilter( + property_path="$.hash", value_to_match=target_id_1.hash, partial_match=False, ), @@ -692,7 +692,7 @@ def test_get_scenario_results_by_target_identifier_filter_hash(sqlite_instance: def 
test_get_scenario_results_by_target_identifier_filter_endpoint(sqlite_instance: MemoryInterface): - """Test filtering scenario results by TargetIdentifierFilter with endpoint.""" + """Test filtering scenario results by identifier filter with endpoint.""" target_id_1 = ComponentIdentifier( class_name="OpenAI", class_module="test", @@ -724,8 +724,8 @@ def test_get_scenario_results_by_target_identifier_filter_endpoint(sqlite_instan # Filter by endpoint partial match results = sqlite_instance.get_scenario_results( - objective_target_identifier_filter=TargetIdentifierFilter( - property_path=TargetIdentifierProperty.ENDPOINT, + objective_target_identifier_filter=IdentifierFilter( + property_path="$.endpoint", value_to_match="openai", partial_match=True, ), @@ -752,8 +752,8 @@ def test_get_scenario_results_by_target_identifier_filter_no_match(sqlite_instan sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario1]) results = sqlite_instance.get_scenario_results( - objective_target_identifier_filter=TargetIdentifierFilter( - property_path=TargetIdentifierProperty.HASH, + objective_target_identifier_filter=IdentifierFilter( + property_path="$.hash", value_to_match="nonexistent_hash", partial_match=False, ), diff --git a/tests/unit/memory/memory_interface/test_interface_scores.py b/tests/unit/memory/memory_interface/test_interface_scores.py index bb9478c3b6..2c90b18313 100644 --- a/tests/unit/memory/memory_interface/test_interface_scores.py +++ b/tests/unit/memory/memory_interface/test_interface_scores.py @@ -13,7 +13,7 @@ from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface, PromptMemoryEntry -from pyrit.memory.identifier_filters import ScorerIdentifierFilter, ScorerIdentifierProperty +from pyrit.memory.identifier_filters import IdentifierFilter from pyrit.models import ( MessagePiece, Score, @@ -262,8 +262,8 @@ def 
test_get_scores_by_scorer_identifier_filter( # Filter by exact class_name match results = sqlite_instance.get_scores( - scorer_identifier_filter=ScorerIdentifierFilter( - property_path=ScorerIdentifierProperty.CLASS_NAME, + scorer_identifier_filter=IdentifierFilter( + property_path="$.class_name", value_to_match="ScorerAlpha", partial_match=False, ), @@ -273,8 +273,8 @@ def test_get_scores_by_scorer_identifier_filter( # Filter by partial class_name match results = sqlite_instance.get_scores( - scorer_identifier_filter=ScorerIdentifierFilter( - property_path=ScorerIdentifierProperty.CLASS_NAME, + scorer_identifier_filter=IdentifierFilter( + property_path="$.class_name", value_to_match="Scorer", partial_match=True, ), @@ -284,8 +284,8 @@ def test_get_scores_by_scorer_identifier_filter( # Filter by hash scorer_hash = score_a.scorer_class_identifier.hash results = sqlite_instance.get_scores( - scorer_identifier_filter=ScorerIdentifierFilter( - property_path=ScorerIdentifierProperty.HASH, + scorer_identifier_filter=IdentifierFilter( + property_path="$.hash", value_to_match=scorer_hash, partial_match=False, ), @@ -295,8 +295,8 @@ def test_get_scores_by_scorer_identifier_filter( # No match results = sqlite_instance.get_scores( - scorer_identifier_filter=ScorerIdentifierFilter( - property_path=ScorerIdentifierProperty.CLASS_NAME, + scorer_identifier_filter=IdentifierFilter( + property_path="$.class_name", value_to_match="NonExistent", partial_match=False, ), From 39361af24d7ce24f0740fed0aaff6eed811ecea2 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 2 Apr 2026 14:54:41 -0700 Subject: [PATCH 9/9] unncecessary post-init --- pyrit/memory/identifier_filters.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pyrit/memory/identifier_filters.py b/pyrit/memory/identifier_filters.py index 74c62c877a..122d89965b 100644 --- a/pyrit/memory/identifier_filters.py +++ b/pyrit/memory/identifier_filters.py @@ -11,7 +11,3 @@ class IdentifierFilter: property_path: str 
value_to_match: str partial_match: bool = False - - def __post_init__(self) -> None: - """Normalize and validate the configured property path.""" - object.__setattr__(self, "property_path", str(self.property_path))