Skip to content

Commit da1143a

Browse files
bashirpartoviBashir Partovi
andauthored
FEAT Refactored XPIA orchestrator as a workflow (microsoft#1062)
Co-authored-by: Bashir Partovi <bpartovi@microsoft.com>
1 parent 775e9c5 commit da1143a

7 files changed

Lines changed: 1463 additions & 39 deletions

File tree

doc/api.rst

Lines changed: 76 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,41 +14,6 @@ API Reference
1414

1515
ConversationAnalytics
1616

17-
:py:mod:`pyrit.attacks`
18-
============================
19-
20-
.. automodule:: pyrit.executor.attack
21-
:no-members:
22-
:no-inherited-members:
23-
24-
.. autosummary::
25-
:nosignatures:
26-
:toctree: _autosummary/
27-
28-
AttackAdversarialConfig
29-
AttackContext
30-
AttackConverterConfig
31-
AttackExecutor
32-
AttackScoringConfig
33-
AttackStrategy
34-
ContextComplianceAttack
35-
ConversationSession
36-
CrescendoAttack
37-
FlipAttack
38-
ManyShotJailbreakAttack
39-
MultiTurnAttackContext
40-
PromptSendingAttack
41-
RTOSystemPromptPaths
42-
RedTeamingAttack
43-
RolePlayAttack
44-
SingleTurnAttackContext
45-
TAPAttack
46-
TAPAttackContext
47-
TAPAttackResult
48-
TreeOfAttacksWithPruningAttack
49-
SkeletonKeyAttack
50-
ConsoleAttackResultPrinter
51-
5217
:py:mod:`pyrit.auth`
5318
====================
5419

@@ -208,6 +173,82 @@ API Reference
208173
RateLimitException
209174
remove_markdown_json
210175

176+
:py:mod:`pyrit.executor.attack`
177+
===============================
178+
179+
.. automodule:: pyrit.executor.attack
180+
:no-members:
181+
:no-inherited-members:
182+
183+
.. autosummary::
184+
:nosignatures:
185+
:toctree: _autosummary/
186+
187+
AttackAdversarialConfig
188+
AttackContext
189+
AttackConverterConfig
190+
AttackExecutor
191+
AttackScoringConfig
192+
AttackStrategy
193+
ContextComplianceAttack
194+
ConversationSession
195+
CrescendoAttack
196+
FlipAttack
197+
ManyShotJailbreakAttack
198+
MultiTurnAttackContext
199+
PromptSendingAttack
200+
RTOSystemPromptPaths
201+
RedTeamingAttack
202+
RolePlayAttack
203+
SingleTurnAttackContext
204+
TAPAttack
205+
TAPAttackContext
206+
TAPAttackResult
207+
TreeOfAttacksWithPruningAttack
208+
SkeletonKeyAttack
209+
ConsoleAttackResultPrinter
210+
211+
:py:mod:`pyrit.executor.promptgen`
212+
==================================
213+
214+
.. automodule:: pyrit.executor.promptgen
215+
:no-members:
216+
:no-inherited-members:
217+
218+
.. autosummary::
219+
:nosignatures:
220+
:toctree: _autosummary/
221+
222+
AnecdoctorContext
223+
AnecdoctorGenerator
224+
AnecdoctorResult
225+
FuzzerContext
226+
FuzzerResult
227+
FuzzerGenerator
228+
FuzzerResultPrinter
229+
PromptGeneratorStrategy
230+
PromptGeneratorStrategyContext
231+
PromptGeneratorStrategyResult
232+
233+
:py:mod:`pyrit.executor.workflow`
234+
=================================
235+
236+
.. automodule:: pyrit.executor.workflow
237+
:no-members:
238+
:no-inherited-members:
239+
240+
.. autosummary::
241+
:nosignatures:
242+
:toctree: _autosummary/
243+
244+
XPIAContext
245+
XPIAResult
246+
XPIAWorkflow
247+
XPIATestWorkflow
248+
XPIAManualProcessingWorkflow
249+
XPIAProcessingCallback
250+
XPIAStatus
251+
211252
:py:mod:`pyrit.memory`
212253
======================
213254

pyrit/common/utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def warn_if_set(
112112
for field_name in unused_fields:
113113
# Get the field value from the config object
114114
if not hasattr(config, field_name):
115-
log.warning(f"Field '{field_name}' does not exist in {config_name}. " f"Skipping unused parameter check.")
115+
log.warning(f"Field '{field_name}' does not exist in {config_name}. Skipping unused parameter check.")
116116
continue
117117

118118
param_value = getattr(config, field_name)
@@ -127,9 +127,7 @@ def warn_if_set(
127127
is_set = True
128128

129129
if is_set:
130-
log.warning(
131-
f"{field_name} was provided in {config_name} but is not used. " f"This parameter will be ignored."
132-
)
130+
log.warning(f"{field_name} was provided in {config_name} but is not used. This parameter will be ignored.")
133131

134132

135133
_T = TypeVar("_T")
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
5+
from pyrit.executor.workflow.xpia import (
6+
XPIAContext,
7+
XPIAResult,
8+
XPIAWorkflow,
9+
XPIATestWorkflow,
10+
XPIAManualProcessingWorkflow,
11+
XPIAProcessingCallback,
12+
XPIAStatus,
13+
)
14+
15+
__all__ = [
16+
"XPIAContext",
17+
"XPIAResult",
18+
"XPIAWorkflow",
19+
"XPIATestWorkflow",
20+
"XPIAManualProcessingWorkflow",
21+
"XPIAProcessingCallback",
22+
"XPIAStatus",
23+
]
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
from pyrit.executor.workflow.core.workflow_strategy import (
5+
WorkflowContext,
6+
WorkflowResult,
7+
WorkflowStrategy,
8+
)
9+
10+
__all__ = [
11+
"WorkflowContext",
12+
"WorkflowResult",
13+
"WorkflowStrategy",
14+
]
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
from __future__ import annotations
5+
6+
import logging
7+
from abc import ABC
8+
from dataclasses import dataclass
9+
from typing import Optional, TypeVar
10+
11+
from pyrit.common.logger import logger
12+
from pyrit.executor.core.strategy import (
13+
Strategy,
14+
StrategyContext,
15+
StrategyEvent,
16+
StrategyEventData,
17+
StrategyEventHandler,
18+
)
19+
from pyrit.models import StrategyResult
20+
21+
WorkflowContextT = TypeVar("WorkflowContextT", bound="WorkflowContext")
22+
WorkflowResultT = TypeVar("WorkflowResultT", bound="WorkflowResult")
23+
24+
25+
@dataclass
26+
class WorkflowContext(StrategyContext, ABC):
27+
"""Base class for all workflow contexts"""
28+
29+
pass
30+
31+
32+
@dataclass
33+
class WorkflowResult(StrategyResult, ABC):
34+
"""Base class for all workflow results"""
35+
36+
pass
37+
38+
39+
class _DefaultWorkflowEventHandler(StrategyEventHandler[WorkflowContextT, WorkflowResultT]):
40+
"""
41+
Default event handler for workflow strategies.
42+
Handles events during the execution of a workflow strategy.
43+
"""
44+
45+
def __init__(self, logger: logging.Logger = logger):
46+
"""
47+
Initialize the default event handler with a logger.
48+
49+
Args:
50+
logger (logging.Logger): Logger instance for logging events.
51+
"""
52+
self._logger = logger
53+
self._events = {
54+
StrategyEvent.ON_PRE_VALIDATE: self._on_pre_validate,
55+
StrategyEvent.ON_POST_VALIDATE: self._on_post_validate,
56+
StrategyEvent.ON_PRE_SETUP: self._on_pre_setup,
57+
StrategyEvent.ON_POST_SETUP: self._on_post_setup,
58+
StrategyEvent.ON_PRE_EXECUTE: self._on_pre_execute,
59+
StrategyEvent.ON_POST_EXECUTE: self._on_post_execute,
60+
StrategyEvent.ON_PRE_TEARDOWN: self._on_pre_teardown,
61+
StrategyEvent.ON_POST_TEARDOWN: self._on_post_teardown,
62+
StrategyEvent.ON_ERROR: self._on_error,
63+
}
64+
65+
async def on_event(self, event_data: StrategyEventData[WorkflowContextT, WorkflowResultT]) -> None:
66+
"""
67+
Handle an event during the workflow strategy execution.
68+
69+
Args:
70+
event_data: The event data containing context and result.
71+
"""
72+
if event_data.event in self._events:
73+
handler = self._events[event_data.event]
74+
await handler(event_data)
75+
76+
async def _on_pre_validate(self, event_data: StrategyEventData[WorkflowContextT, WorkflowResultT]) -> None:
77+
self._logger.debug(f"Starting validation for workflow {event_data.strategy_name}")
78+
79+
async def _on_post_validate(self, event_data: StrategyEventData[WorkflowContextT, WorkflowResultT]) -> None:
80+
self._logger.debug(f"Validation completed for workflow {event_data.strategy_name}")
81+
82+
async def _on_pre_setup(self, event_data: StrategyEventData[WorkflowContextT, WorkflowResultT]) -> None:
83+
self._logger.debug(f"Starting setup for workflow {event_data.strategy_name}")
84+
85+
async def _on_post_setup(self, event_data: StrategyEventData[WorkflowContextT, WorkflowResultT]) -> None:
86+
self._logger.debug(f"Setup completed for workflow {event_data.strategy_name}")
87+
88+
async def _on_pre_execute(self, event_data: StrategyEventData[WorkflowContextT, WorkflowResultT]) -> None:
89+
self._logger.info(f"Starting execution of workflow {event_data.strategy_name}")
90+
91+
async def _on_post_execute(self, event_data: StrategyEventData[WorkflowContextT, WorkflowResultT]) -> None:
92+
self._logger.info(f"Workflow {event_data.strategy_name} completed.")
93+
94+
async def _on_pre_teardown(self, event_data: StrategyEventData[WorkflowContextT, WorkflowResultT]) -> None:
95+
self._logger.debug(f"Starting teardown for workflow {event_data.strategy_name}")
96+
97+
async def _on_post_teardown(self, event_data: StrategyEventData[WorkflowContextT, WorkflowResultT]) -> None:
98+
self._logger.debug(f"Teardown completed for workflow {event_data.strategy_name}")
99+
100+
async def _on_error(self, event_data: StrategyEventData[WorkflowContextT, WorkflowResultT]) -> None:
101+
self._logger.error(
102+
f"Error in workflow {event_data.strategy_name}: {event_data.error}", exc_info=event_data.error
103+
)
104+
105+
106+
class WorkflowStrategy(Strategy[WorkflowContextT, WorkflowResultT], ABC):
107+
"""
108+
Abstract base class for workflow strategies.
109+
Defines the interface for executing workflows and handling results.
110+
"""
111+
112+
def __init__(
113+
self,
114+
*,
115+
context_type: type[WorkflowContextT],
116+
logger: logging.Logger = logger,
117+
event_handler: Optional[StrategyEventHandler[WorkflowContextT, WorkflowResultT]] = None,
118+
):
119+
"""
120+
Initialize the workflow strategy with a specific context type and logger.
121+
122+
Args:
123+
context_type: The type of context this strategy operates on.
124+
logger: Logger instance for logging events.
125+
event_handler: Optional custom event handler for workflow events.
126+
"""
127+
default_handler = _DefaultWorkflowEventHandler[WorkflowContextT, WorkflowResultT](logger=logger)
128+
super().__init__(
129+
context_type=context_type,
130+
event_handler=event_handler or default_handler,
131+
logger=logger,
132+
)

0 commit comments

Comments
 (0)