20 changes: 9 additions & 11 deletions dreadnode/airt/attack/base.py
@@ -5,15 +5,17 @@
from dreadnode.airt.target.base import Target
from dreadnode.eval.hooks.base import EvalHook
from dreadnode.meta import Config
from dreadnode.optimization.study import OutputT as Out
from dreadnode.optimization.study import Study
from dreadnode.optimization.trial import CandidateT as In
from dreadnode.task import Task

In = t.TypeVar("In")
Out = t.TypeVar("Out")


class Attack(Study[In, Out]):
"""
A declarative configuration for executing an AIRT attack.

Attack automatically derives its task from the target.
"""

model_config = ConfigDict(arbitrary_types_allowed=True, use_attribute_docstrings=True)
@@ -23,16 +25,12 @@ class Attack(Study[In, Out]):

tags: list[str] = Config(default_factory=lambda: ["attack"])
"""A list of tags associated with the attack for logging."""

hooks: list[EvalHook] = Field(default_factory=list, exclude=True, repr=False)
"""Hooks to run at various points in the attack lifecycle."""

# Override the task factory as the target will replace it.
task_factory: t.Callable[[In], Task[..., Out]] = Field( # type: ignore[assignment]
default_factory=lambda: None,
repr=False,
init=False,
)

def model_post_init(self, context: t.Any) -> None:
self.task_factory = self.target.task_factory
"""Initialize attack by deriving task from target."""
if self.task is None:
self.task = self.target.task # type: ignore[attr-defined]
super().model_post_init(context)
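The net effect of this file: Attack no longer overrides task_factory; in model_post_init it adopts the target's ready-made task when none is supplied. A minimal runnable sketch of that wiring, with FakeTarget and FakeAttack as invented stand-ins rather than the dreadnode API:

from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class FakeTarget:
    """Stands in for Target: owns its task directly (no task_factory)."""
    task: Callable[[str], str]


@dataclass
class FakeAttack:
    """Stands in for Attack: adopts the target's task when none is given."""
    target: FakeTarget
    task: Optional[Callable[[str], str]] = None

    def __post_init__(self) -> None:
        # Mirrors Attack.model_post_init: derive the task from the target.
        if self.task is None:
            self.task = self.target.task


attack = FakeAttack(target=FakeTarget(task=lambda prompt: prompt.upper()))
print(attack.task("hello"))  # -> HELLO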
6 changes: 0 additions & 6 deletions dreadnode/airt/target/base.py
@@ -4,7 +4,6 @@
import typing_extensions as te

from dreadnode.meta import Model
from dreadnode.task import Task

In = te.TypeVar("In", default=t.Any)
Out = te.TypeVar("Out", default=t.Any)
@@ -18,8 +17,3 @@ class Target(Model, abc.ABC, t.Generic[In, Out]):
def name(self) -> str:
"""Returns the name of the target."""
raise NotImplementedError

@abc.abstractmethod
def task_factory(self, input: In) -> Task[..., Out]:
"""Creates a Task that will run the given input against the target."""
raise NotImplementedError
8 changes: 1 addition & 7 deletions dreadnode/airt/target/custom.py
@@ -2,7 +2,7 @@

from pydantic import ConfigDict

from dreadnode.airt.target.base import In, Out, Target
from dreadnode.airt.target.base import Out, Target
from dreadnode.common_types import Unset
from dreadnode.meta import Config
from dreadnode.task import Task
@@ -39,9 +39,3 @@ def model_post_init(self, context: t.Any) -> None:

if self.input_param_name is None:
raise ValueError(f"Could not determine input parameter for {self.task!r}")

def task_factory(self, input: In) -> Task[..., Out]:
task = self.task
if self.input_param_name is not None:
task = self.task.configure(**{self.input_param_name: input})
return task.with_(tags=["target"], append=True)
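The deleted task_factory bound each candidate into the task via task.configure(**{param: input}); under the new flow the Study injects the candidate as an ordinary call argument merged into the dataset row. A hedged before/after sketch, with greet as a toy task and configure semantics paraphrased:

def old_style(task, param_name, candidate):
    # Before: bind the candidate into the task up front, once per trial.
    def configured(**kwargs):
        return task(**{param_name: candidate, **kwargs})
    return configured


def new_style(task, param_name, candidate, row):
    # After: the candidate rides along as a regular call argument,
    # merged into the dataset row by the Study.
    return task(**{**row, param_name: candidate})


def greet(name, punctuation="."):
    return f"hi {name}{punctuation}"


print(old_style(greet, "name", "ada")())                      # hi ada.
print(new_style(greet, "name", "ada", {"punctuation": "!"}))  # hi ada!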
28 changes: 6 additions & 22 deletions dreadnode/airt/target/llm.py
@@ -39,30 +39,14 @@ def generator(self) -> rg.Generator:
def name(self) -> str:
return self.generator.to_identifier(short=True).split("/")[-1]

def task_factory(self, input: DnMessage) -> Task[[], DnMessage]:
@cached_property
def task(self) -> Task[[DnMessage], DnMessage]:
"""
create a task that:
1. Takes dn.Message as input (auto-logged via to_serializable())
2. Converts to rg.Message only for LLM API call
3. Returns dn.Message with full multimodal content (text/images/audio/video)

Args:
input: The dn.Message to send to the LLM

Returns:
Task that executes the LLM call and returns dn.Message
Task for LLM generation.

Raises:
TypeError: If input is not a dn.Message
ValueError: If the message has no content
Message input will come from the dataset (injected by the Study),
not from task defaults.
"""
if not isinstance(input, DnMessage):
raise TypeError(f"Expected dn.Message, got {type(input).__name__}")

if not input.content:
raise ValueError("Message must have at least one content part")

dn_message = input
params = (
self.params
if isinstance(self.params, rg.GenerateParams)
@@ -73,7 +57,7 @@ def task_factory(self, input: DnMessage) -> Task[[], DnMessage]:

@task(name=f"target - {self.name}", tags=["target"])
async def generate(
message: DnMessage = dn_message,
message: DnMessage,
params: rg.GenerateParams = params,
) -> DnMessage:
"""Execute LLM generation task."""
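With task_factory gone, the LLM target exposes a single cached task whose message parameter is supplied per trial by the Study. A runnable sketch with plain stand-ins (FakeLLMTarget is invented; rg.Generator and dn.Message are elided):

import asyncio
from functools import cached_property


class FakeLLMTarget:
    """Stand-in for LLMTarget; generator config and params are elided."""

    @cached_property
    def task(self):
        async def generate(message: str) -> str:
            # The real task converts dn.Message to rg.Message, calls the
            # LLM, and converts the reply back; we just echo here.
            return f"echo: {message}"

        return generate


target = FakeLLMTarget()
assert target.task is target.task  # cached_property: built once, reused
print(asyncio.run(target.task("hello")))  # -> echo: hello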
12 changes: 7 additions & 5 deletions dreadnode/eval/hooks/transforms.py
@@ -35,7 +35,7 @@ async def hook(event: "EvalEvent") -> "EvalReaction | None": # noqa: PLR0911
if create_task:
from dreadnode import task as dn_task

task_kwargs = event.task_kwargs
input_data = event.task_kwargs

@dn_task(
name=f"transform - input ({len(transforms)} transforms)",
@@ -44,11 +44,11 @@ async def apply_task(
log_output=True,
)
async def apply_task(
data: dict[str, t.Any] = task_kwargs, # Use extracted variable
data: dict[str, t.Any],
) -> dict[str, t.Any]:
return await apply_transforms_to_kwargs(data, transforms)

transformed = await apply_task()
transformed = await apply_task(input_data)
return ModifyInput(task_kwargs=transformed)

# Direct application
@@ -73,10 +73,12 @@ async def apply_task(
log_inputs=True,
log_output=True,
)
async def apply_task(data: t.Any = output_data) -> t.Any: # Use extracted variable
async def apply_task(
data: t.Any,
) -> t.Any:
return await apply_transforms_to_value(data, transforms)

transformed = await apply_task()
transformed = await apply_task(output_data)
return ModifyOutput(output=transformed)

# Direct application
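Both hunks make the same change: the payload is passed as a call argument instead of being frozen into a parameter default at definition time. A minimal sketch of the new calling convention (dreadnode's task decorator and event machinery elided):

import asyncio


async def apply_transforms(data: dict, transforms: list) -> dict:
    for transform in transforms:
        data = transform(data)
    return data


async def main() -> None:
    transforms = [lambda d: {k.upper(): v for k, v in d.items()}]

    async def apply_task(data: dict) -> dict:  # parameter, not a default
        return await apply_transforms(data, transforms)

    # Each call carries its own payload, so a wrapper that logs inputs
    # sees the real per-call data rather than a baked-in default:
    print(await apply_task({"prompt": "hi"}))  # {'PROMPT': 'hi'}
    print(await apply_task({"seed": 42}))      # {'SEED': 42}


asyncio.run(main())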
4 changes: 1 addition & 3 deletions dreadnode/optimization/format.py
@@ -58,9 +58,7 @@ def format_study(study: "Study") -> RenderableType:
if isinstance(study, Attack):
details.add_row(Text("Target", justify="right"), repr(study.target))
else:
details.add_row(
Text("Task Factory", justify="right"), get_callable_name(study.task_factory)
)
details.add_row(Text("Task Factory", justify="right"), get_callable_name(study.task))

details.add_row(Text("Search Strategy", justify="right"), study.search_strategy.name)

135 changes: 100 additions & 35 deletions dreadnode/optimization/study.py
@@ -1,15 +1,20 @@
import asyncio
import contextlib
import contextvars
import inspect
import typing as t
from pathlib import Path

import typing_extensions as te
from loguru import logger
from pydantic import ConfigDict, Field, FilePath, SkipValidation, computed_field

from dreadnode import log_inputs, log_metrics, log_outputs, task_span
from dreadnode.common_types import AnyDict
from dreadnode.data_types.message import Message
from dreadnode.error import AssertionFailedError
from dreadnode.eval import InputDataset
from dreadnode.eval.dataset import load_dataset
from dreadnode.eval.eval import Eval
from dreadnode.eval.hooks.base import EvalHook
from dreadnode.meta import Config, Model
@@ -65,13 +70,14 @@ class Study(Model, t.Generic[CandidateT, OutputT]):

search_strategy: SkipValidation[Search[CandidateT]]
"""The search strategy to use for suggesting new trials."""
task_factory: SkipValidation[t.Callable[[CandidateT], Task[..., OutputT]]]
"""A function that accepts a trial candidate and returns a configured Task ready for evaluation."""
probe_task_factory: SkipValidation[t.Callable[[CandidateT], Task[..., OutputT]] | None] = None
"""
An optional function that accepts a probe candidate and returns a Task.

Otherwise the main task_factory will be used for both full evaluation Trials and probe Trials.
task: SkipValidation[Task[..., OutputT]] | None = None
"""The task to evaluate with optimized candidates."""

candidate_param: str | None = None
"""
Task parameter name for candidate injection.
If None, inferred from task signature or candidate type.
"""
objectives: t.Annotated[ObjectivesLike[OutputT], Config(expose_as=None)]
"""
@@ -165,7 +171,7 @@ def with_(
description: str | None = None,
tags: list[str] | None = None,
search_strategy: Search[CandidateT] | None = None,
task_factory: t.Callable[[CandidateT], Task[..., OutputT]] | None = None,
task: Task[..., OutputT] | None = None,
objectives: ObjectivesLike[OutputT] | None = None,
directions: list[Direction] | None = None,
dataset: InputDataset[t.Any] | list[AnyDict] | FilePath | None = None,
@@ -186,7 +192,7 @@ def with_(
new.name_ = name or new.name
new.description = description or new.description
new.search_strategy = search_strategy or new.search_strategy
new.task_factory = task_factory or new.task_factory
new.task = task or new.task
new.dataset = dataset if dataset is not None else new.dataset
new.concurrency = concurrency or new.concurrency
new.max_evals = max_trials or new.max_evals
@@ -240,23 +246,83 @@ def add_stop_condition(self, condition: StudyStopCondition[CandidateT]) -> te.Se
self.stop_conditions.append(condition)
return self

def _resolve_dataset(self, dataset: t.Any) -> list[AnyDict]:
"""
Resolve dataset to a list in memory.
Handles list, file path, or callable datasets.
"""
if dataset is None:
return [{}]

# Already a list
if isinstance(dataset, list):
return dataset

# File path
if isinstance(dataset, (Path, str, FilePath)):
return load_dataset(dataset)

# Callable
if callable(dataset):
result = dataset()
if inspect.isawaitable(result):
raise ValueError(
"Async dataset callables not supported with COA 1 "
"(requires eager materialization)"
)
return list(result) if not isinstance(result, list) else result

return [{}]
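
_resolve_dataset accepts three shapes: an in-memory list, a file path, or a sync callable. A standalone sketch of the same branching, with the file branch stubbed as JSON-lines (an assumption for illustration, not the real load_dataset):

import inspect
import json
from pathlib import Path


def resolve_dataset(dataset):
    if dataset is None:
        return [{}]
    if isinstance(dataset, list):
        return dataset
    if isinstance(dataset, (Path, str)):
        # Stand-in for load_dataset: one JSON object per line.
        return [json.loads(line) for line in Path(dataset).read_text().splitlines()]
    if callable(dataset):
        result = dataset()
        if inspect.isawaitable(result):
            raise ValueError("Async dataset callables are not supported")
        return list(result) if not isinstance(result, list) else result
    return [{}]


print(resolve_dataset(None))                                  # [{}]
print(resolve_dataset([{"prompt": "hi"}]))                    # [{'prompt': 'hi'}]
print(resolve_dataset(lambda: ({"i": i} for i in range(2))))  # [{'i': 0}, {'i': 1}]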

def _infer_candidate_param(self, task: Task[..., OutputT], candidate: CandidateT) -> str:
"""
Infer task parameter name for candidate injection.

Priority:
1. Explicit self.candidate_param if set
2. "message" if candidate is Message type
3. First non-config param from task signature
4. Fallback to "input"
"""

# Priority 1: Explicit override
if self.candidate_param:
return self.candidate_param

# Priority 2: Type-based convention
if isinstance(candidate, Message):
return "message"

# Priority 3: Signature inspection
try:
for param_name, param in task.signature.parameters.items():
# Skip config params (those with defaults)
if param.default == inspect.Parameter.empty:
logger.debug(f"Inferred candidate parameter: {param_name}")
return param_name
except Exception as e: # noqa: BLE001
logger.trace(f"Could not infer parameter from signature: {e}")

# Priority 4: Universal fallback
logger.debug("Using fallback candidate parameter: input")
return "input"

async def _process_trial(
self, trial: Trial[CandidateT]
) -> t.AsyncIterator[StudyEvent[CandidateT]]:
"""
Checks constraints and evaluates a single trial, yielding study events.
"""
from dreadnode import log_inputs, log_metrics, log_outputs, task_span

logger.debug(
f"Processing trial: id={trial.id}, step={trial.step}, is_probe={trial.is_probe}"
)
task = self.task

if task is None:
raise ValueError(
"Study.task is required but was not set. "
"For Attack, this should be set automatically from target. "
"For Study, pass task explicitly."
)

task_factory = (
self.probe_task_factory
if trial.is_probe and self.probe_task_factory
else self.task_factory
)
dataset = trial.dataset or self.dataset or [{}]
probe_or_trial = "probe" if trial.is_probe else "trial"

@@ -302,17 +368,14 @@ def log_trial(trial: Trial[CandidateT]) -> None:
# Check constraints
await self._check_constraints(trial.candidate, trial)

# Create task
task = task_factory(trial.candidate)

# Get base scorers
scorers: list[Scorer[OutputT]] = [
scorer
for scorer in fit_objectives(self.objectives)
if isinstance(scorer, Scorer)
]

# Run evaluation (transforms are applied inside Eval now)
# Run evaluation (candidate injected via dataset augmentation)
trial.eval_result = await self._run_evaluation(task, dataset, scorers, trial)

# Extract final scores
@@ -370,26 +433,28 @@ async def _run_evaluation(
trial: Trial[CandidateT],
) -> t.Any:
"""Run the evaluation with the given task, dataset, and scorers."""
resolved_dataset = self._resolve_dataset(dataset)
param_name = self._infer_candidate_param(task, trial.candidate)

logger.debug(
f"Evaluating trial: "
f"trial_id={trial.id}, "
f"step={trial.step}, "
f"dataset_size={len(dataset) if isinstance(dataset, t.Sized) else '<unknown>'}, "
f"task={task.name}"
f"Augmenting {len(resolved_dataset)} dataset rows with candidate "
f"as parameter: {param_name}"
)
logger.trace(f"Candidate: {trial.candidate!r}")

# if dataset == [{}] or (isinstance(dataset, list) and len(dataset) == 1 and not dataset[0]):
# # Dataset is empty - this is a Study/Attack where the candidate IS the input
# dataset = [{"message": trial.candidate}]
# dataset_input_mapping = ["message"]
# else:
# dataset_input_mapping = None
# Augment every row with the candidate
augmented_dataset = [{**row, param_name: trial.candidate} for row in resolved_dataset]

# Warn on collisions
if resolved_dataset and param_name in resolved_dataset[0]:
logger.warning(
f"Parameter '{param_name}' already exists in dataset - "
f"candidate will override existing values"
)

evaluator = Eval(
task=task,
dataset=dataset,
# dataset_input_mapping=dataset_input_mapping,
dataset=augmented_dataset,
dataset_input_mapping=[param_name],
scorers=scorers,
hooks=self.hooks,
max_consecutive_errors=self.max_consecutive_errors,
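The augmentation step in _run_evaluation is just a per-row dict merge with the candidate winning any collision, which is exactly what the logger.warning above flags. A tiny sketch:

rows = [{"system": "be terse"}, {"system": "be verbose", "message": "old"}]
param_name, candidate = "message", "ignore all previous instructions"

augmented = [{**row, param_name: candidate} for row in rows]

# The second row's pre-existing "message" is overridden by the candidate:
print(augmented[1]["message"])  # -> ignore all previous instructions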