Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,9 @@ def get_message_pieces(
Exception: If there is an error retrieving the prompts,
an exception is logged and an empty list is returned.
"""
if prompt_ids is not None and len(prompt_ids) == 0:
return []

conditions = []
if attack_id:
conditions.append(self._get_message_pieces_attack_conditions(attack_id=str(attack_id)))
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/memory/memory_interface/test_interface_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ def test_get_message_pieces_uuid_and_string_ids(sqlite_instance: MemoryInterface
assert str(single_str_result[0].id) == str(uuid3)


def test_get_message_pieces_empty_prompt_ids_returns_empty(sqlite_instance: MemoryInterface):
piece = MessagePiece(
id=uuid.uuid4(),
role="user",
original_value="Test prompt",
converted_value="Test prompt",
)
sqlite_instance.add_message_pieces_to_memory(message_pieces=[piece])

assert sqlite_instance.get_message_pieces(prompt_ids=[]) == []


def test_duplicate_memory(sqlite_instance: MemoryInterface):
attack1 = PromptSendingAttack(objective_target=get_mock_target())
attack2 = PromptSendingAttack(objective_target=get_mock_target("Target2"))
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/memory/memory_interface/test_interface_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,31 @@ def test_add_score_get_score(
assert db_score[0].message_piece_id == prompt_id


def test_get_prompt_scores_empty_prompt_ids_returns_empty(sqlite_instance: MemoryInterface):
prompt_id = uuid4()
piece = MessagePiece(
id=prompt_id,
role="user",
original_value="original prompt text",
converted_value="Hello, how are you?",
)
sqlite_instance.add_message_pieces_to_memory(message_pieces=[piece])

score = Score(
score_value=str(0.8),
score_value_description="High score",
score_type="float_scale",
score_category=["test"],
score_rationale="Test score",
score_metadata={"test": "metadata"},
scorer_class_identifier=_test_scorer_id("TestScorer"),
message_piece_id=prompt_id,
)
sqlite_instance.add_scores_to_memory(scores=[score])

assert sqlite_instance.get_prompt_scores(prompt_ids=[]) == []


def test_add_score_duplicate_prompt(sqlite_instance: MemoryInterface):
# Ensure that scores of duplicate prompts are linked back to the original
original_id = uuid4()
Expand Down
Loading