Skip to content

Commit 1eef610

Browse files
authored
Merge pull request #21 from Agent-One-Lab/algorithm
Update qa reward and fix async retrieval bug
2 parents c53bac3 + 54f3acc commit 1eef610

4 files changed

Lines changed: 83 additions & 3 deletions

File tree

docs/features/reward_system.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ We put reward calculation into the agent side instead of trainer side and use a
55
2. Reward calculation can be designed to be asynchronous for efficiency.
66

77
### Definition
8-
Similar to tools, we can decide whether to use environments in the reward definition. The return should either be a value, or a dictionary containing `reward` as one of keys. We can use decorator `@tool` or inherit from the `BaseReward` class.
8+
Similar to tools, we can decide whether to use environments in the reward definition. The return should either be a value, or a dictionary containing `reward` as one of its keys. We can use the decorator `@reward` or inherit from the `BaseReward` class. Any additional keys in the returned dict (e.g. `em`, `f1`, `fmt`) are passed through and logged during training and validation.
9+
10+
911

1012
```python
1113
@reward(name="qa_f1_reward")
@@ -105,3 +107,9 @@ def summary_reward(final_response, length_penalty, max_length):
105107
else:
106108
return 1.0
107109
```
110+
111+
## Return Values
112+
113+
Either a `float` value or a dictionary containing `reward` as a key should be returned. If the return value is a `float`, it is used directly as the reward. If a dictionary is returned, the value under the `reward` key is used as the reward, while the other keys are still logged as metrics.
114+
115+
Extra keys (besides `reward`) are logged as `reward_extra/{key}/mean`, `reward_extra/{key}/max`, `reward_extra/{key}/min` in the metrics produced by `compute_data_metrics` (`verl/verl/trainer/ppo/metric_utils.py`).

src/agentfly/rewards/qa_reward.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,78 @@ def qa_f1_reward_tool(final_response: str, answer: str, trajectory: List[str]) -
149149
return rewards_dict
150150

151151

152+
def _extract_answer_tag(text: str) -> str:
153+
"""Extract content between <answer> and </answer>, or return original if not present."""
154+
match = re.search(r"<answer>\s*(.*?)\s*</answer>", text, re.DOTALL)
155+
return match.group(1).strip() if match else text
156+
157+
158+
def _format_ok(final_response: str, trajectory: List) -> tuple:
    """Check the structural format of a QA rollout.

    Returns a 3-tuple ``(fmt, previous_have_think, has_tool_calling)``.
    ``fmt`` is True only when all three hold: *final_response* contains an
    <answer>...</answer> pair, *trajectory* contains at least one message
    with role "tool", and every assistant turn except the last contains a
    <think>...</think> pair.
    """
    has_answer_tags = "<answer>" in final_response and "</answer>" in final_response
    # Short-circuit: without answer tags (or with an empty trajectory) the
    # check fails outright and the other two flags are reported as False
    # without being computed.
    if not has_answer_tags or not trajectory:
        return False, False, False
    has_tool_calling = any(
        isinstance(msg, dict) and msg.get("role") == "tool" for msg in trajectory
    )
    # Collect assistant turns; only previous (non-last) ones must have think
    assistant_turns = []
    for msg in trajectory:
        if isinstance(msg, dict):
            if msg.get("role") == "assistant":
                # Falls back to "text", then "", when "content" is missing
                # or falsy.
                content = msg.get("content") or msg.get("text") or ""
                assistant_turns.append(content)
            elif msg.get("role") == "tool":
                pass  # already counted has_tool_calling
        else:
            # NOTE(review): non-dict entries are treated as assistant turns.
            # Indentation was reconstructed from a diff paste — confirm this
            # `else` pairs with the isinstance check, not the role chain.
            assistant_turns.append(str(msg))
    if not assistant_turns:
        previous_have_think = True
    else:
        previous = assistant_turns[:-1]  # all but last
        # Empty/falsy turns are skipped by the `if c` filter, so they do not
        # fail the think check.
        previous_have_think = all(
            "<think>" in c and "</think>" in c for c in previous if c
        )
    fmt = has_answer_tags and has_tool_calling and previous_have_think
    return fmt, previous_have_think, has_tool_calling
186+
187+
188+
@reward(name="qa_em_format_reward")
189+
def qa_em_format_reward(final_response: str, golden_answers: List[str], trajectory: List[str]) -> float:
190+
"""
191+
Calculate the reward for the agent's response based on the EM score.
192+
193+
- 1.0 if the format is correct, and the em is true
194+
- 0.1 if the format is correct, but the em is wrong
195+
- 0.0 if the format is incorrect
196+
"""
197+
predicted = _extract_answer_tag(final_response)
198+
if not golden_answers:
199+
max_em, max_f1 = 0.0, 0.0
200+
else:
201+
max_em = max(em_score(predicted, g) for g in golden_answers)
202+
max_f1 = max(f1_score(predicted, g)[0] for g in golden_answers)
203+
fmt, previous_have_think, has_tool_calling = _format_ok(final_response, trajectory)
204+
205+
reward = 0.0
206+
if fmt and max_em:
207+
reward = 1.0
208+
elif fmt and not max_em:
209+
reward = 0.1
210+
elif max_em and not fmt:
211+
reward = 0.0
212+
213+
return {
214+
"reward": reward,
215+
"em": max_em,
216+
"f1": max_f1,
217+
"fmt": 1.0 if fmt else 0.0,
218+
"fmt_think": 1.0 if previous_have_think else 0.0,
219+
"fmt_tool": 1.0 if has_tool_calling else 0.0,
220+
}
221+
222+
223+
152224
@reward(name="ok_vqa_reward")
153225
def ok_vqa_reward(
154226
final_response: str, answers: List[str], trajectory: List[str]

src/agentfly/tools/src/search/async_dense_retriever.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from collections import deque
88
from concurrent.futures import ThreadPoolExecutor
99
from functools import lru_cache
10-
10+
from .... import AGENT_CACHE_DIR
1111
import datasets
1212
import numpy as np
1313
import torch

0 commit comments

Comments
 (0)