diff --git a/pyrit/memory/__init__.py b/pyrit/memory/__init__.py index 9f10860130..6098122d7d 100644 --- a/pyrit/memory/__init__.py +++ b/pyrit/memory/__init__.py @@ -9,6 +9,7 @@ from pyrit.memory.azure_sql_memory import AzureSQLMemory from pyrit.memory.central_memory import CentralMemory +from pyrit.memory.identifier_filters import IdentifierFilter from pyrit.memory.memory_embedding import MemoryEmbedding from pyrit.memory.memory_exporter import MemoryExporter from pyrit.memory.memory_interface import MemoryInterface @@ -26,4 +27,5 @@ "MemoryExporter", "PromptMemoryEntry", "SeedEntry", + "IdentifierFilter", ] diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index c9f349c0d9..916f64508d 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -250,22 +250,6 @@ def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str condition = text(conditions).bindparams(**{key: str(value) for key, value in memory_labels.items()}) return [condition] - def _get_message_pieces_attack_conditions(self, *, attack_id: str) -> Any: - """ - Generate SQL condition for filtering message pieces by attack ID. - - Uses JSON_VALUE() function specific to SQL Azure to query the attack identifier. - - Args: - attack_id (str): The attack identifier to filter by. - - Returns: - Any: SQLAlchemy text condition with bound parameter. - """ - return text("ISJSON(attack_identifier) = 1 AND JSON_VALUE(attack_identifier, '$.hash') = :json_id").bindparams( - json_id=str(attack_id) - ) - def _get_metadata_conditions(self, *, prompt_metadata: dict[str, Union[str, int]]) -> list[TextClause]: """ Generate SQL conditions for filtering by prompt metadata. 
@@ -321,6 +305,93 @@ def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]])
         """
         return self._get_metadata_conditions(prompt_metadata=metadata)[0]
 
+    def _get_condition_json_property_match(
+        self,
+        *,
+        json_column: Any,
+        property_path: str,
+        value_to_match: str,
+        partial_match: bool = False,
+    ) -> Any:
+        table_name = json_column.class_.__tablename__
+        column_name = json_column.key
+
+        return text(
+            f"""ISJSON("{table_name}".{column_name}) = 1
+            AND LOWER(JSON_VALUE("{table_name}".{column_name}, :property_path)) {"LIKE" if partial_match else "="} :match_property_value"""  # noqa: E501
+        ).bindparams(
+            property_path=property_path,
+            match_property_value=f"%{value_to_match.lower()}%" if partial_match else value_to_match.lower(),
+        )
+
+    def _get_condition_json_array_match(
+        self,
+        *,
+        json_column: Any,
+        property_path: str,
+        sub_path: str | None = None,
+        array_to_match: Sequence[str],
+    ) -> Any:
+        table_name = json_column.class_.__tablename__
+        column_name = json_column.key
+        if len(array_to_match) == 0:
+            return text(
+                f"""("{table_name}".{column_name} IS NULL
+                OR JSON_QUERY("{table_name}".{column_name}, :property_path) IS NULL
+                OR JSON_QUERY("{table_name}".{column_name}, :property_path) = '[]')"""
+            ).bindparams(property_path=property_path)
+
+        value_expression = f"LOWER(JSON_VALUE(value, '{sub_path}'))" if sub_path else "LOWER(value)"
+
+        conditions = []
+        bindparams_dict: dict[str, str] = {"property_path": property_path}
+
+        for index, match_value in enumerate(array_to_match):
+            param_name = f"match_value_{index}"
+            conditions.append(
+                f"""EXISTS(SELECT 1 FROM OPENJSON(JSON_QUERY("{table_name}".{column_name},
+                :property_path))
+                WHERE {value_expression} = :{param_name})"""
+            )
+            bindparams_dict[param_name] = match_value.lower()
+
+        combined = " AND ".join(conditions)
+        return text(f"""ISJSON("{table_name}".{column_name}) = 1 AND {combined}""").bindparams(**bindparams_dict)
+
+    def _get_unique_json_array_values(
+        self,
+        *,
+        json_column: Any,
+        path_to_array: str,
+ sub_path: str | None = None, + ) -> list[str]: + table_name = json_column.class_.__tablename__ + column_name = json_column.key + with closing(self.get_session()) as session: + if sub_path is None: + rows = session.execute( + text( + f"""SELECT DISTINCT JSON_VALUE("{table_name}".{column_name}, :path_to_array) AS value + FROM "{table_name}" + WHERE ISJSON("{table_name}".{column_name}) = 1 + AND JSON_VALUE("{table_name}".{column_name}, :path_to_array) IS NOT NULL""" + ).bindparams(path_to_array=path_to_array) + ).fetchall() + else: + rows = session.execute( + text( + f"""SELECT DISTINCT JSON_VALUE(items.value, :sub_path) AS value + FROM "{table_name}" + CROSS APPLY OPENJSON(JSON_QUERY("{table_name}".{column_name}, :path_to_array)) AS items + WHERE ISJSON("{table_name}".{column_name}) = 1 + AND JSON_VALUE(items.value, :sub_path) IS NOT NULL""" + ).bindparams( + path_to_array=path_to_array, + sub_path=sub_path, + ) + ).fetchall() + return sorted(row[0] for row in rows) + def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any: """ Get the SQL Azure implementation for filtering AttackResults by targeted harm categories. @@ -388,110 +459,6 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: ) ) - def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any: - """ - Azure SQL implementation for filtering AttackResults by attack class. - Uses JSON_VALUE() on the atomic_attack_identifier JSON column. - - Args: - attack_class (str): Exact attack class name to match. - - Returns: - Any: SQLAlchemy text condition with bound parameter. 
- """ - return text( - """ISJSON("AttackResultEntries".atomic_attack_identifier) = 1 - AND JSON_VALUE("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.class_name') = :attack_class""" - ).bindparams(attack_class=attack_class) - - def _get_attack_result_converter_classes_condition(self, *, converter_classes: Sequence[str]) -> Any: - """ - Azure SQL implementation for filtering AttackResults by converter classes. - - Uses JSON_VALUE()/JSON_QUERY()/OPENJSON() on the atomic_attack_identifier - JSON column. - - When converter_classes is empty, matches attacks with no converters. - When non-empty, uses OPENJSON() to check all specified classes are present - (AND logic, case-insensitive). - - Args: - converter_classes (Sequence[str]): List of converter class names. Empty list means no converters. - - Returns: - Any: SQLAlchemy combined condition with bound parameters. - """ - if len(converter_classes) == 0: - # Explicitly "no converters": match attacks where the converter list - # is absent, null, or empty in the stored JSON. 
- return text( - """("AttackResultEntries".atomic_attack_identifier IS NULL - OR JSON_QUERY("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.children.request_converters') IS NULL - OR JSON_QUERY("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.children.request_converters') = '[]')""" - ) - - conditions = [] - bindparams_dict: dict[str, str] = {} - for i, cls in enumerate(converter_classes): - param_name = f"conv_cls_{i}" - conditions.append( - f"""EXISTS(SELECT 1 FROM OPENJSON(JSON_QUERY("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.children.request_converters')) - WHERE LOWER(JSON_VALUE(value, '$.class_name')) = :{param_name})""" - ) - bindparams_dict[param_name] = cls.lower() - - combined = " AND ".join(conditions) - return text(f"""ISJSON("AttackResultEntries".atomic_attack_identifier) = 1 AND {combined}""").bindparams( - **bindparams_dict - ) - - def get_unique_attack_class_names(self) -> list[str]: - """ - Azure SQL implementation: extract unique class_name values from - the atomic_attack_identifier JSON column. - - Returns: - Sorted list of unique attack class name strings. - """ - with closing(self.get_session()) as session: - rows = session.execute( - text( - """SELECT DISTINCT JSON_VALUE(atomic_attack_identifier, - '$.children.attack.class_name') AS cls - FROM "AttackResultEntries" - WHERE ISJSON(atomic_attack_identifier) = 1 - AND JSON_VALUE(atomic_attack_identifier, - '$.children.attack.class_name') IS NOT NULL""" - ) - ).fetchall() - return sorted(row[0] for row in rows) - - def get_unique_converter_class_names(self) -> list[str]: - """ - Azure SQL implementation: extract unique converter class_name values - from the children.attack.children.request_converters array - in the atomic_attack_identifier JSON column. - - Returns: - Sorted list of unique converter class name strings. 
- """ - with closing(self.get_session()) as session: - rows = session.execute( - text( - """SELECT DISTINCT JSON_VALUE(c.value, '$.class_name') AS cls - FROM "AttackResultEntries" - CROSS APPLY OPENJSON(JSON_QUERY(atomic_attack_identifier, - '$.children.attack.children.request_converters')) AS c - WHERE ISJSON(atomic_attack_identifier) = 1 - AND JSON_VALUE(c.value, '$.class_name') IS NOT NULL""" - ) - ).fetchall() - return sorted(row[0] for row in rows) - def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str, ConversationStats]: """ Azure SQL implementation: lightweight aggregate stats per conversation. @@ -593,46 +560,33 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any conditions.append(condition) return and_(*conditions) - def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> TextClause: + def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: """ - Get the SQL Azure implementation for filtering ScenarioResults by target endpoint. - - Uses JSON_VALUE() function specific to SQL Azure. - - Args: - endpoint (str): The endpoint URL substring to filter by (case-insensitive). + Insert a list of message pieces into the memory storage. - Returns: - Any: SQLAlchemy text condition with bound parameter. """ - return text( - """ISJSON(objective_target_identifier) = 1 - AND LOWER(JSON_VALUE(objective_target_identifier, '$.endpoint')) LIKE :endpoint""" - ).bindparams(endpoint=f"%{endpoint.lower()}%") + self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in message_pieces]) - def _get_scenario_result_target_model_condition(self, *, model_name: str) -> TextClause: + def get_unique_attack_class_names(self) -> list[str]: """ - Get the SQL Azure implementation for filtering ScenarioResults by target model name. - - Uses JSON_VALUE() function specific to SQL Azure. 
- - Args: - model_name (str): The model name substring to filter by (case-insensitive). + Azure SQL implementation: extract unique class_name values from + the atomic_attack_identifier JSON column. Returns: - Any: SQLAlchemy text condition with bound parameter. + Sorted list of unique attack class name strings. """ - return text( - """ISJSON(objective_target_identifier) = 1 - AND LOWER(JSON_VALUE(objective_target_identifier, '$.model_name')) LIKE :model_name""" - ).bindparams(model_name=f"%{model_name.lower()}%") + return super().get_unique_attack_class_names() - def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: + def get_unique_converter_class_names(self) -> list[str]: """ - Insert a list of message pieces into the memory storage. + Azure SQL implementation: extract unique converter class_name values + from the children.attack.children.request_converters array + in the atomic_attack_identifier JSON column. + Returns: + Sorted list of unique converter class name strings. """ - self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in message_pieces]) + return super().get_unique_converter_class_names() def dispose_engine(self) -> None: """ diff --git a/pyrit/memory/identifier_filters.py b/pyrit/memory/identifier_filters.py new file mode 100644 index 0000000000..122d89965b --- /dev/null +++ b/pyrit/memory/identifier_filters.py @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from dataclasses import dataclass + + +@dataclass(frozen=True) +class IdentifierFilter: + """Immutable filter definition for matching JSON-backed identifier properties.""" + + property_path: str + value_to_match: str + partial_match: bool = False diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 90322ebec4..1ef99789d8 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -19,6 +19,7 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute from pyrit.common.path import DB_DATA_PATH +from pyrit.memory.identifier_filters import IdentifierFilter from pyrit.memory.memory_embedding import ( MemoryEmbedding, default_memory_embedding_factory, @@ -113,6 +114,77 @@ def disable_embedding(self) -> None: """ self.memory_embedding = None + @abc.abstractmethod + def _get_condition_json_property_match( + self, + *, + json_column: InstrumentedAttribute[Any], + property_path: str, + value_to_match: str, + partial_match: bool = False, + ) -> Any: + """ + Return a database-specific condition for matching a value at a given path within a JSON object. + + Args: + json_column (InstrumentedAttribute[Any]): The JSON-backed model field to query. + property_path (str): The JSON path for the property to match. + value_to_match (str): The string value that must match the extracted JSON property value. + partial_match (bool): Whether to perform a case-insensitive substring match. + + Returns: + Any: A SQLAlchemy condition for the backend-specific JSON query. + """ + + @abc.abstractmethod + def _get_condition_json_array_match( + self, + *, + json_column: InstrumentedAttribute[Any], + property_path: str, + sub_path: Optional[str] = None, + array_to_match: Sequence[str], + ) -> Any: + """ + Return a database-specific condition for matching an array at a given path within a JSON object. + + Args: + json_column (InstrumentedAttribute[Any]): The JSON-backed SQLAlchemy field to query. 
+            property_path (str): The JSON path for the target array.
+            sub_path (Optional[str]): An optional JSON path applied to each array item before matching.
+            array_to_match (Sequence[str]): The array that must match the extracted JSON array values.
+            For a match, ALL values in this array must be present in the JSON array.
+            If `array_to_match` is empty, the condition must match only if the target is also an empty array or None.
+
+        Returns:
+            Any: A database-specific SQLAlchemy condition.
+        """
+
+    @abc.abstractmethod
+    def _get_unique_json_array_values(
+        self,
+        *,
+        json_column: Any,
+        path_to_array: str,
+        sub_path: str | None = None,
+    ) -> list[str]:
+        """
+        Return sorted unique values in an array located at a given path within a JSON object.
+
+        This method performs a database-level query to extract distinct values from
+        an array within a JSON-type column. When ``sub_path`` is provided, the distinct values are
+        extracted from each array item using the sub-path.
+
+        Args:
+            json_column (Any): The JSON-backed model field to query.
+            path_to_array (str): The JSON path to the array whose unique values are extracted.
+            sub_path (str | None): Optional JSON path applied to each array
+            item before collecting distinct values.
+
+        Returns:
+            list[str]: A sorted list of unique values in the array.
+        """
+
     @abc.abstractmethod
     def get_all_embeddings(self) -> Sequence[EmbeddingDataEntry]:
         """
@@ -155,12 +227,6 @@ def _get_message_pieces_prompt_metadata_conditions(
             list: A list of conditions for filtering memory entries based on prompt metadata.
         """
 
-    @abc.abstractmethod
-    def _get_message_pieces_attack_conditions(self, *, attack_id: str) -> Any:
-        """
-        Return a condition to retrieve based on attack ID.
-        """
-
     @abc.abstractmethod
     def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) -> Any:
         """
@@ -289,41 +355,6 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any:
             Database-specific SQLAlchemy condition.
""" - @abc.abstractmethod - def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any: - """ - Return a database-specific condition for filtering AttackResults by attack class - (class_name in the attack_identifier JSON column). - - Args: - attack_class: Exact attack class name to match. - - Returns: - Database-specific SQLAlchemy condition. - """ - - @abc.abstractmethod - def _get_attack_result_converter_classes_condition(self, *, converter_classes: Sequence[str]) -> Any: - """ - Return a database-specific condition for filtering AttackResults by converter classes - in the request_converter_identifiers array within attack_identifier JSON column. - - This method is only called when converter filtering is requested (converter_classes - is not None). The caller handles the None-vs-list distinction: - - - ``len(converter_classes) == 0``: return a condition matching attacks with NO converters. - - ``len(converter_classes) > 0``: return a condition requiring ALL specified converter - class names to be present (AND logic, case-insensitive). - - Args: - converter_classes: Converter class names to require. An empty sequence means - "match only attacks that have no converters". - - Returns: - Database-specific SQLAlchemy condition. - """ - - @abc.abstractmethod def get_unique_attack_class_names(self) -> list[str]: """ Return sorted unique attack class names from all stored attack results. @@ -334,8 +365,11 @@ def get_unique_attack_class_names(self) -> list[str]: Returns: Sorted list of unique attack class name strings. """ + return self._get_unique_json_array_values( + json_column=AttackResultEntry.atomic_attack_identifier, + path_to_array="$.children.attack.class_name", + ) - @abc.abstractmethod def get_unique_converter_class_names(self) -> list[str]: """ Return sorted unique converter class names used across all attack results. 
@@ -346,6 +380,11 @@ def get_unique_converter_class_names(self) -> list[str]: Returns: Sorted list of unique converter class name strings. """ + return self._get_unique_json_array_values( + json_column=AttackResultEntry.atomic_attack_identifier, + path_to_array="$.children.attack.children.request_converters", + sub_path="$.class_name", + ) @abc.abstractmethod def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str, "ConversationStats"]: @@ -377,30 +416,6 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any Database-specific SQLAlchemy condition. """ - @abc.abstractmethod - def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> Any: - """ - Return a database-specific condition for filtering ScenarioResults by target endpoint. - - Args: - endpoint: Endpoint substring to search for (case-insensitive). - - Returns: - Database-specific SQLAlchemy condition. - """ - - @abc.abstractmethod - def _get_scenario_result_target_model_condition(self, *, model_name: str) -> Any: - """ - Return a database-specific condition for filtering ScenarioResults by target model name. - - Args: - model_name: Model name substring to search for (case-insensitive). - - Returns: - Database-specific SQLAlchemy condition. - """ - def add_scores_to_memory(self, *, scores: Sequence[Score]) -> None: """ Insert a list of scores into the memory storage. @@ -425,6 +440,7 @@ def get_scores( score_category: Optional[str] = None, sent_after: Optional[datetime] = None, sent_before: Optional[datetime] = None, + scorer_identifier_filter: Optional[IdentifierFilter] = None, ) -> Sequence[Score]: """ Retrieve a list of Score objects based on the specified filters. @@ -435,6 +451,8 @@ def get_scores( score_category (Optional[str]): The category of the score to filter by. sent_after (Optional[datetime]): Filter for scores sent after this datetime. sent_before (Optional[datetime]): Filter for scores sent before this datetime. 
+ scorer_identifier_filter (Optional[IdentifierFilter]): An IdentifierFilter object that + allows filtering by various scorer identifier JSON properties. Defaults to None. Returns: Sequence[Score]: A list of Score objects that match the specified filters. @@ -451,6 +469,15 @@ def get_scores( conditions.append(ScoreEntry.timestamp >= sent_after) if sent_before: conditions.append(ScoreEntry.timestamp <= sent_before) + if scorer_identifier_filter: + conditions.append( + self._get_condition_json_property_match( + json_column=ScoreEntry.scorer_class_identifier, + property_path=scorer_identifier_filter.property_path, + value_to_match=scorer_identifier_filter.value_to_match, + partial_match=scorer_identifier_filter.partial_match, + ) + ) if not conditions: return [] @@ -581,6 +608,8 @@ def get_message_pieces( data_type: Optional[str] = None, not_data_type: Optional[str] = None, converted_value_sha256: Optional[Sequence[str]] = None, + attack_identifier_filter: Optional[IdentifierFilter] = None, + prompt_target_identifier_filter: Optional[IdentifierFilter] = None, ) -> Sequence[MessagePiece]: """ Retrieve a list of MessagePiece objects based on the specified filters. @@ -602,6 +631,12 @@ def get_message_pieces( not_data_type (Optional[str], optional): The data type to exclude. Defaults to None. converted_value_sha256 (Optional[Sequence[str]], optional): A list of SHA256 hashes of converted values. Defaults to None. + attack_identifier_filter (Optional[IdentifierFilter], optional): + An IdentifierFilter object that + allows filtering by various attack identifier JSON properties. Defaults to None. + prompt_target_identifier_filter (Optional[IdentifierFilter], optional): + An IdentifierFilter object that + allows filtering by various target identifier JSON properties. Defaults to None. Returns: Sequence[MessagePiece]: A list of MessagePiece objects that match the specified filters. 
@@ -612,7 +647,13 @@ def get_message_pieces( """ conditions = [] if attack_id: - conditions.append(self._get_message_pieces_attack_conditions(attack_id=str(attack_id))) + conditions.append( + self._get_condition_json_property_match( + json_column=PromptMemoryEntry.attack_identifier, + property_path="$.hash", + value_to_match=str(attack_id), + ) + ) if role: conditions.append(PromptMemoryEntry.role == role) if conversation_id: @@ -638,6 +679,24 @@ def get_message_pieces( conditions.append(PromptMemoryEntry.converted_value_data_type != not_data_type) if converted_value_sha256: conditions.append(PromptMemoryEntry.converted_value_sha256.in_(converted_value_sha256)) + if attack_identifier_filter: + conditions.append( + self._get_condition_json_property_match( + json_column=PromptMemoryEntry.attack_identifier, + property_path=attack_identifier_filter.property_path, + value_to_match=attack_identifier_filter.value_to_match, + partial_match=attack_identifier_filter.partial_match, + ) + ) + if prompt_target_identifier_filter: + conditions.append( + self._get_condition_json_property_match( + json_column=PromptMemoryEntry.prompt_target_identifier, + property_path=prompt_target_identifier_filter.property_path, + value_to_match=prompt_target_identifier_filter.value_to_match, + partial_match=prompt_target_identifier_filter.partial_match, + ) + ) try: memory_entries: Sequence[PromptMemoryEntry] = self._query_entries( @@ -1365,6 +1424,7 @@ def get_attack_results( converter_classes: Optional[Sequence[str]] = None, targeted_harm_categories: Optional[Sequence[str]] = None, labels: Optional[dict[str, str]] = None, + attack_identifier_filter: Optional[IdentifierFilter] = None, ) -> Sequence[AttackResult]: """ Retrieve a list of AttackResult objects based on the specified filters. @@ -1392,6 +1452,9 @@ def get_attack_results( labels (Optional[dict[str, str]], optional): A dictionary of memory labels to filter results by. 
These labels are associated with the prompts themselves, used for custom tagging and tracking. Defaults to None. + attack_identifier_filter (Optional[IdentifierFilter], optional): + An IdentifierFilter object that allows filtering by various attack identifier + JSON properties. Defaults to None. Returns: Sequence[AttackResult]: A list of AttackResult objects that match the specified filters. @@ -1415,12 +1478,25 @@ def get_attack_results( if attack_class: # Use database-specific JSON query method - conditions.append(self._get_attack_result_attack_class_condition(attack_class=attack_class)) + conditions.append( + self._get_condition_json_property_match( + json_column=AttackResultEntry.atomic_attack_identifier, + property_path="$.children.attack.class_name", + value_to_match=attack_class, + ) + ) if converter_classes is not None: # converter_classes=[] means "only attacks with no converters" # converter_classes=["A","B"] means "must have all listed converters" - conditions.append(self._get_attack_result_converter_classes_condition(converter_classes=converter_classes)) + conditions.append( + self._get_condition_json_array_match( + json_column=AttackResultEntry.atomic_attack_identifier, + property_path="$.children.attack.children.request_converters", + sub_path="$.class_name", + array_to_match=converter_classes, + ) + ) if targeted_harm_categories: # Use database-specific JSON query method @@ -1432,6 +1508,16 @@ def get_attack_results( # Use database-specific JSON query method conditions.append(self._get_attack_result_label_condition(labels=labels)) + if attack_identifier_filter: + conditions.append( + self._get_condition_json_property_match( + json_column=AttackResultEntry.atomic_attack_identifier, + property_path=attack_identifier_filter.property_path, + value_to_match=attack_identifier_filter.value_to_match, + partial_match=attack_identifier_filter.partial_match, + ) + ) + try: entries: Sequence[AttackResultEntry] = self._query_entries( AttackResultEntry, 
conditions=and_(*conditions) if conditions else None @@ -1612,6 +1698,7 @@ def get_scenario_results( labels: Optional[dict[str, str]] = None, objective_target_endpoint: Optional[str] = None, objective_target_model_name: Optional[str] = None, + objective_target_identifier_filter: Optional[IdentifierFilter] = None, ) -> Sequence[ScenarioResult]: """ Retrieve a list of ScenarioResult objects based on the specified filters. @@ -1635,6 +1722,8 @@ def get_scenario_results( objective_target_model_name (Optional[str], optional): Filter for scenarios where the objective_target_identifier has a model_name attribute containing this value (case-insensitive). Defaults to None. + objective_target_identifier_filter (Optional[IdentifierFilter], optional): + An IdentifierFilter object that allows filtering by various target identifier JSON properties. Returns: Sequence[ScenarioResult]: A list of ScenarioResult objects that match the specified filters. @@ -1672,11 +1761,35 @@ def get_scenario_results( if objective_target_endpoint: # Use database-specific JSON query method - conditions.append(self._get_scenario_result_target_endpoint_condition(endpoint=objective_target_endpoint)) + conditions.append( + self._get_condition_json_property_match( + json_column=ScenarioResultEntry.objective_target_identifier, + property_path="$.endpoint", + value_to_match=objective_target_endpoint, + partial_match=True, + ) + ) if objective_target_model_name: # Use database-specific JSON query method - conditions.append(self._get_scenario_result_target_model_condition(model_name=objective_target_model_name)) + conditions.append( + self._get_condition_json_property_match( + json_column=ScenarioResultEntry.objective_target_identifier, + property_path="$.model_name", + value_to_match=objective_target_model_name, + partial_match=True, + ) + ) + + if objective_target_identifier_filter: + conditions.append( + self._get_condition_json_property_match( + json_column=ScenarioResultEntry.objective_target_identifier, 
+ property_path=objective_target_identifier_filter.property_path, + value_to_match=objective_target_identifier_filter.value_to_match, + partial_match=objective_target_identifier_filter.partial_match, + ) + ) try: entries: Sequence[ScenarioResultEntry] = self._query_entries( diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 7bd05b4f82..fa9487055e 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -177,15 +177,6 @@ def _get_message_pieces_prompt_metadata_conditions( condition = text(json_conditions).bindparams(**{key: str(value) for key, value in prompt_metadata.items()}) return [condition] - def _get_message_pieces_attack_conditions(self, *, attack_id: str) -> Any: - """ - Generate SQLAlchemy filter conditions for filtering by attack ID. - - Returns: - Any: A SQLAlchemy text condition with bound parameters. - """ - return text("JSON_EXTRACT(attack_identifier, '$.hash') = :attack_id").bindparams(attack_id=str(attack_id)) - def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) -> Any: """ Generate SQLAlchemy filter conditions for filtering seed prompts by metadata. 
@@ -199,6 +190,103 @@ def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) # Note: We do NOT convert values to string here, to allow integer comparison in JSON return text(json_conditions).bindparams(**dict(metadata.items())) + def _get_condition_json_property_match( + self, + *, + json_column: Any, + property_path: str, + value_to_match: str, + partial_match: bool = False, + ) -> Any: + extracted_value = func.lower(func.json_extract(json_column, property_path)) + if partial_match: + return extracted_value.like(f"%{value_to_match.lower()}%") + return extracted_value == value_to_match.lower() + + def _get_condition_json_array_match( + self, + *, + json_column: Any, + property_path: str, + sub_path: str | None = None, + array_to_match: Sequence[str], + ) -> Any: + array_expr = func.json_extract(json_column, property_path) + if len(array_to_match) == 0: + return or_( + json_column.is_(None), + array_expr.is_(None), + array_expr == "[]", + ) + + table_name = json_column.class_.__tablename__ + column_name = json_column.key + value_expression = f"LOWER(json_extract(value, '{sub_path}'))" if sub_path else "LOWER(value)" + + conditions = [] + for index, match_value in enumerate(array_to_match): + param_name = f"match_value_{index}" + bind_params: dict[str, str] = { + "property_path": property_path, + param_name: match_value.lower(), + } + conditions.append( + text( + f"""EXISTS(SELECT 1 FROM json_each( + json_extract("{table_name}".{column_name}, :property_path)) + WHERE {value_expression} = :{param_name})""" + ).bindparams(**bind_params) + ) + return and_(*conditions) + + def _get_unique_json_array_values( + self, + *, + json_column: Any, + path_to_array: str, + sub_path: str | None = None, + ) -> list[str]: + with closing(self.get_session()) as session: + if sub_path is None: + property_expr = func.json_extract(json_column, path_to_array) + rows = session.query(property_expr).filter(property_expr.isnot(None)).distinct().all() + else: + 
table_name = json_column.class_.__tablename__ + column_name = json_column.key + rows = session.execute( + text( + f"""SELECT DISTINCT json_extract(j.value, :sub_path) AS value + FROM "{table_name}", + json_each(json_extract("{table_name}".{column_name}, :path_to_array)) AS j + WHERE json_extract(j.value, :sub_path) IS NOT NULL""" + ).bindparams( + path_to_array=path_to_array, + sub_path=sub_path, + ) + ).fetchall() + return sorted(row[0] for row in rows) + + def get_unique_attack_class_names(self) -> list[str]: + """ + SQLite implementation: extract unique class_name values from + the atomic_attack_identifier JSON column. + + Returns: + Sorted list of unique attack class name strings. + """ + return super().get_unique_attack_class_names() + + def get_unique_converter_class_names(self) -> list[str]: + """ + SQLite implementation: extract unique converter class_name values + from the children.attack.children.request_converters array in the + atomic_attack_identifier JSON column. + + Returns: + Sorted list of unique converter class name strings. + """ + return super().get_unique_converter_class_names() + def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: """ Insert a list of message pieces into the memory storage. @@ -526,97 +614,6 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: ) return labels_subquery # noqa: RET504 - def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any: - """ - SQLite implementation for filtering AttackResults by attack class. - Uses json_extract() on the atomic_attack_identifier JSON column. - - Returns: - Any: A SQLAlchemy condition for filtering by attack class. 
- """ - return ( - func.json_extract(AttackResultEntry.atomic_attack_identifier, "$.children.attack.class_name") - == attack_class - ) - - def _get_attack_result_converter_classes_condition(self, *, converter_classes: Sequence[str]) -> Any: - """ - SQLite implementation for filtering AttackResults by converter classes. - - Uses json_extract() on the atomic_attack_identifier JSON column. - - When converter_classes is empty, matches attacks with no converters - (children.attack.children.request_converters is absent or null in the JSON). - When non-empty, uses json_each() to check all specified classes are present - (AND logic, case-insensitive). - - Returns: - Any: A SQLAlchemy condition for filtering by converter classes. - """ - if len(converter_classes) == 0: - # Explicitly "no converters": match attacks where the converter list - # is absent, null, or empty in the stored JSON. - converter_json = func.json_extract( - AttackResultEntry.atomic_attack_identifier, - "$.children.attack.children.request_converters", - ) - return or_( - AttackResultEntry.atomic_attack_identifier.is_(None), - converter_json.is_(None), - converter_json == "[]", - ) - - conditions = [] - for i, cls in enumerate(converter_classes): - param_name = f"conv_cls_{i}" - conditions.append( - text( - f"""EXISTS(SELECT 1 FROM json_each( - json_extract("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.children.request_converters')) - WHERE LOWER(json_extract(value, '$.class_name')) = :{param_name})""" - ).bindparams(**{param_name: cls.lower()}) - ) - return and_(*conditions) - - def get_unique_attack_class_names(self) -> list[str]: - """ - SQLite implementation: extract unique class_name values from - the atomic_attack_identifier JSON column. - - Returns: - Sorted list of unique attack class name strings. 
- """ - with closing(self.get_session()) as session: - class_name_expr = func.json_extract( - AttackResultEntry.atomic_attack_identifier, "$.children.attack.class_name" - ) - rows = session.query(class_name_expr).filter(class_name_expr.isnot(None)).distinct().all() - return sorted(row[0] for row in rows) - - def get_unique_converter_class_names(self) -> list[str]: - """ - SQLite implementation: extract unique converter class_name values - from the children.attack.children.request_converters array in the - atomic_attack_identifier JSON column. - - Returns: - Sorted list of unique converter class name strings. - """ - with closing(self.get_session()) as session: - rows = session.execute( - text( - """SELECT DISTINCT json_extract(j.value, '$.class_name') AS cls - FROM "AttackResultEntries", - json_each( - json_extract("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.children.request_converters') - ) AS j - WHERE cls IS NOT NULL""" - ) - ).fetchall() - return sorted(row[0] for row in rows) - def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str, ConversationStats]: """ SQLite implementation: lightweight aggregate stats per conversation. @@ -710,27 +707,3 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any return and_( *[func.json_extract(ScenarioResultEntry.labels, f"$.{key}") == value for key, value in labels.items()] ) - - def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> Any: - """ - SQLite implementation for filtering ScenarioResults by target endpoint. - Uses json_extract() function specific to SQLite. - - Returns: - Any: A SQLAlchemy subquery for filtering by target endpoint. 
- """ - return func.lower(func.json_extract(ScenarioResultEntry.objective_target_identifier, "$.endpoint")).like( - f"%{endpoint.lower()}%" - ) - - def _get_scenario_result_target_model_condition(self, *, model_name: str) -> Any: - """ - SQLite implementation for filtering ScenarioResults by target model name. - Uses json_extract() function specific to SQLite. - - Returns: - Any: A SQLAlchemy subquery for filtering by target model name. - """ - return func.lower(func.json_extract(ScenarioResultEntry.objective_target_identifier, "$.model_name")).like( - f"%{model_name.lower()}%" - ) diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index 91367c3a1c..84cda0b409 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -9,6 +9,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier from pyrit.memory import MemoryInterface +from pyrit.memory.identifier_filters import IdentifierFilter from pyrit.memory.memory_models import AttackResultEntry from pyrit.models import ( AttackOutcome, @@ -1175,15 +1176,6 @@ def test_get_attack_results_by_attack_class_no_match(sqlite_instance: MemoryInte assert len(results) == 0 -def test_get_attack_results_by_attack_class_case_sensitive(sqlite_instance: MemoryInterface): - """Test that attack_class filter is case-sensitive (exact match).""" - ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") - sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) - - results = sqlite_instance.get_attack_results(attack_class="crescendoattack") - assert len(results) == 0 - - def test_get_attack_results_by_attack_class_no_identifier(sqlite_instance: MemoryInterface): """Test that attacks with no attack_identifier (empty JSON) are excluded by 
attack_class filter.""" ar1 = create_attack_result("conv_1", 1) # No attack_identifier → stored as {} @@ -1352,3 +1344,55 @@ def test_get_unique_converter_class_names_skips_no_converters(sqlite_instance: M result = sqlite_instance.get_unique_converter_class_names() assert result == ["Base64Converter"] + + +def test_get_attack_results_by_attack_identifier_filter_hash(sqlite_instance: MemoryInterface): + """Test filtering attack results by IdentifierFilter with hash.""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") + ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2]) + + # Filter by hash of ar1's attack identifier + results = sqlite_instance.get_attack_results( + attack_identifier_filter=IdentifierFilter( + property_path="$.hash", + value_to_match=ar1.atomic_attack_identifier.hash, + partial_match=False, + ), + ) + assert len(results) == 1 + assert results[0].conversation_id == "conv_1" + + +def test_get_attack_results_by_attack_identifier_filter_class_name(sqlite_instance: MemoryInterface): + """Test filtering attack results by IdentifierFilter with class_name.""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") + ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack") + ar3 = _make_attack_result_with_identifier("conv_3", "CrescendoAttack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) + + # Filter by partial attack class name + results = sqlite_instance.get_attack_results( + attack_identifier_filter=IdentifierFilter( + property_path="$.children.attack.class_name", + value_to_match="Crescendo", + partial_match=True, + ), + ) + assert len(results) == 2 + assert {r.conversation_id for r in results} == {"conv_1", "conv_3"} + + +def test_get_attack_results_by_attack_identifier_filter_no_match(sqlite_instance: MemoryInterface): + """Test that IdentifierFilter returns
empty when nothing matches.""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) + + results = sqlite_instance.get_attack_results( + attack_identifier_filter=IdentifierFilter( + property_path="$.hash", + value_to_match="nonexistent_hash", + partial_match=False, + ), + ) + assert len(results) == 0 diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index 67a4292f87..eec4d3d88a 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -14,6 +14,7 @@ from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface, PromptMemoryEntry +from pyrit.memory.identifier_filters import IdentifierFilter from pyrit.models import ( Message, MessagePiece, @@ -1248,3 +1249,112 @@ def test_get_request_from_response_raises_error_for_sequence_less_than_one(sqlit with pytest.raises(ValueError, match="The provided request does not have a preceding request \\(sequence < 1\\)."): sqlite_instance.get_request_from_response(response=response_without_request) + + +def test_get_message_pieces_by_attack_identifier_filter(sqlite_instance: MemoryInterface): + attack1 = PromptSendingAttack(objective_target=get_mock_target()) + attack2 = PromptSendingAttack(objective_target=get_mock_target("Target2")) + + entries = [ + PromptMemoryEntry( + entry=MessagePiece( + role="user", + original_value="Hello 1", + attack_identifier=attack1.get_identifier(), + ) + ), + PromptMemoryEntry( + entry=MessagePiece( + role="assistant", + original_value="Hello 2", + attack_identifier=attack2.get_identifier(), + ) + ), + ] + + sqlite_instance._insert_entries(entries=entries) + + # Filter by exact attack hash + results = sqlite_instance.get_message_pieces( 
+ attack_identifier_filter=IdentifierFilter( + property_path="$.hash", + value_to_match=attack1.get_identifier().hash, + partial_match=False, + ), + ) + assert len(results) == 1 + assert results[0].original_value == "Hello 1" + + # No match + results = sqlite_instance.get_message_pieces( + attack_identifier_filter=IdentifierFilter( + property_path="$.hash", + value_to_match="nonexistent_hash", + partial_match=False, + ), + ) + assert len(results) == 0 + + +def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryInterface): + target_id_1 = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target", + params={"endpoint": "https://api.openai.com", "model_name": "gpt-4"}, + ) + target_id_2 = ComponentIdentifier( + class_name="AzureChatTarget", + class_module="pyrit.prompt_target", + params={"endpoint": "https://azure.com", "model_name": "gpt-3.5"}, + ) + + entries = [ + PromptMemoryEntry( + entry=MessagePiece( + role="user", + original_value="Hello OpenAI", + prompt_target_identifier=target_id_1, + ) + ), + PromptMemoryEntry( + entry=MessagePiece( + role="user", + original_value="Hello Azure", + prompt_target_identifier=target_id_2, + ) + ), + ] + + sqlite_instance._insert_entries(entries=entries) + + # Filter by target hash + results = sqlite_instance.get_message_pieces( + prompt_target_identifier_filter=IdentifierFilter( + property_path="$.hash", + value_to_match=target_id_1.hash, + partial_match=False, + ), + ) + assert len(results) == 1 + assert results[0].original_value == "Hello OpenAI" + + # Filter by endpoint partial match + results = sqlite_instance.get_message_pieces( + prompt_target_identifier_filter=IdentifierFilter( + property_path="$.endpoint", + value_to_match="openai", + partial_match=True, + ), + ) + assert len(results) == 1 + assert results[0].original_value == "Hello OpenAI" + + # No match + results = sqlite_instance.get_message_pieces( + prompt_target_identifier_filter=IdentifierFilter( + 
property_path="$.hash", + value_to_match="nonexistent", + partial_match=False, + ), + ) + assert len(results) == 0 diff --git a/tests/unit/memory/memory_interface/test_interface_scenario_results.py b/tests/unit/memory/memory_interface/test_interface_scenario_results.py index e513e8b873..ee2933b70a 100644 --- a/tests/unit/memory/memory_interface/test_interface_scenario_results.py +++ b/tests/unit/memory/memory_interface/test_interface_scenario_results.py @@ -9,6 +9,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface +from pyrit.memory.identifier_filters import IdentifierFilter from pyrit.models import ( AttackOutcome, AttackResult, @@ -645,3 +646,116 @@ def test_combined_filters(sqlite_instance: MemoryInterface): assert len(results) == 1 assert results[0].scenario_identifier.pyrit_version == "0.5.0" assert "gpt-4" in results[0].objective_target_identifier.params["model_name"] + + +def test_get_scenario_results_by_target_identifier_filter_hash(sqlite_instance: MemoryInterface): + """Test filtering scenario results by identifier filter.""" + target_id_1 = ComponentIdentifier( + class_name="OpenAI", + class_module="test", + params={"endpoint": "https://api.openai.com", "model_name": "gpt-4"}, + ) + target_id_2 = ComponentIdentifier( + class_name="Azure", + class_module="test", + params={"endpoint": "https://azure.com", "model_name": "gpt-3.5"}, + ) + + attack_result1 = create_attack_result("conv_1", "Objective 1") + attack_result2 = create_attack_result("conv_2", "Objective 2") + sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2]) + + scenario1 = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="Scenario OpenAI", scenario_version=1), + objective_target_identifier=target_id_1, + attack_results={"Attack1": [attack_result1]}, + objective_scorer_identifier=get_mock_scorer_identifier(), + ) + scenario2 = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="Scenario 
Azure", scenario_version=1), + objective_target_identifier=target_id_2, + attack_results={"Attack2": [attack_result2]}, + objective_scorer_identifier=get_mock_scorer_identifier(), + ) + sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario1, scenario2]) + + # Filter by target hash + results = sqlite_instance.get_scenario_results( + objective_target_identifier_filter=IdentifierFilter( + property_path="$.hash", + value_to_match=target_id_1.hash, + partial_match=False, + ), + ) + assert len(results) == 1 + assert results[0].scenario_identifier.name == "Scenario OpenAI" + + +def test_get_scenario_results_by_target_identifier_filter_endpoint(sqlite_instance: MemoryInterface): + """Test filtering scenario results by identifier filter with endpoint.""" + target_id_1 = ComponentIdentifier( + class_name="OpenAI", + class_module="test", + params={"endpoint": "https://api.openai.com", "model_name": "gpt-4"}, + ) + target_id_2 = ComponentIdentifier( + class_name="Azure", + class_module="test", + params={"endpoint": "https://azure.com", "model_name": "gpt-3.5"}, + ) + + attack_result1 = create_attack_result("conv_1", "Objective 1") + attack_result2 = create_attack_result("conv_2", "Objective 2") + sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2]) + + scenario1 = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="Scenario OpenAI", scenario_version=1), + objective_target_identifier=target_id_1, + attack_results={"Attack1": [attack_result1]}, + objective_scorer_identifier=get_mock_scorer_identifier(), + ) + scenario2 = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="Scenario Azure", scenario_version=1), + objective_target_identifier=target_id_2, + attack_results={"Attack2": [attack_result2]}, + objective_scorer_identifier=get_mock_scorer_identifier(), + ) + sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario1, scenario2]) + + # Filter by endpoint partial match + results 
= sqlite_instance.get_scenario_results( + objective_target_identifier_filter=IdentifierFilter( + property_path="$.endpoint", + value_to_match="openai", + partial_match=True, + ), + ) + assert len(results) == 1 + assert results[0].scenario_identifier.name == "Scenario OpenAI" + + +def test_get_scenario_results_by_target_identifier_filter_no_match(sqlite_instance: MemoryInterface): + """Test that IdentifierFilter returns empty when nothing matches.""" + attack_result1 = create_attack_result("conv_1", "Objective 1") + sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1]) + + scenario1 = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="Test Scenario", scenario_version=1), + objective_target_identifier=ComponentIdentifier( + class_name="OpenAI", + class_module="test", + params={"endpoint": "https://api.openai.com"}, + ), + attack_results={"Attack1": [attack_result1]}, + objective_scorer_identifier=get_mock_scorer_identifier(), + ) + sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario1]) + + results = sqlite_instance.get_scenario_results( + objective_target_identifier_filter=IdentifierFilter( + property_path="$.hash", + value_to_match="nonexistent_hash", + partial_match=False, + ), + ) + assert len(results) == 0 diff --git a/tests/unit/memory/memory_interface/test_interface_scores.py b/tests/unit/memory/memory_interface/test_interface_scores.py index 6087af1418..2c90b18313 100644 --- a/tests/unit/memory/memory_interface/test_interface_scores.py +++ b/tests/unit/memory/memory_interface/test_interface_scores.py @@ -13,6 +13,7 @@ from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface, PromptMemoryEntry +from pyrit.memory.identifier_filters import IdentifierFilter from pyrit.models import ( MessagePiece, Score, @@ -227,3 +228,77 @@ async def test_get_seeds_no_filters(sqlite_instance:
MemoryInterface): assert len(result) == 2 assert result[0].value == "prompt1" assert result[1].value == "prompt2" + + +def test_get_scores_by_scorer_identifier_filter( + sqlite_instance: MemoryInterface, + sample_conversation_entries: Sequence[PromptMemoryEntry], +): + prompt_id = sample_conversation_entries[0].id + sqlite_instance._insert_entries(entries=sample_conversation_entries) + + score_a = Score( + score_value="0.9", + score_value_description="High", + score_type="float_scale", + score_category=["cat_a"], + score_rationale="Rationale A", + score_metadata={}, + scorer_class_identifier=_test_scorer_id("ScorerAlpha"), + message_piece_id=prompt_id, + ) + score_b = Score( + score_value="0.1", + score_value_description="Low", + score_type="float_scale", + score_category=["cat_b"], + score_rationale="Rationale B", + score_metadata={}, + scorer_class_identifier=_test_scorer_id("ScorerBeta"), + message_piece_id=prompt_id, + ) + + sqlite_instance.add_scores_to_memory(scores=[score_a, score_b]) + + # Filter by exact class_name match + results = sqlite_instance.get_scores( + scorer_identifier_filter=IdentifierFilter( + property_path="$.class_name", + value_to_match="ScorerAlpha", + partial_match=False, + ), + ) + assert len(results) == 1 + assert results[0].score_value == "0.9" + + # Filter by partial class_name match + results = sqlite_instance.get_scores( + scorer_identifier_filter=IdentifierFilter( + property_path="$.class_name", + value_to_match="Scorer", + partial_match=True, + ), + ) + assert len(results) == 2 + + # Filter by hash + scorer_hash = score_a.scorer_class_identifier.hash + results = sqlite_instance.get_scores( + scorer_identifier_filter=IdentifierFilter( + property_path="$.hash", + value_to_match=scorer_hash, + partial_match=False, + ), + ) + assert len(results) == 1 + assert results[0].score_value == "0.9" + + # No match + results = sqlite_instance.get_scores( + scorer_identifier_filter=IdentifierFilter( + property_path="$.class_name", + 
value_to_match="NonExistent", + partial_match=False, + ), + ) + assert len(results) == 0