36 commits
114a15a
fix(pu): add threading.Lock() in retrieval_server.py
puyuan1996 Nov 18, 2025
695a7b4
feature(pu): add init version of off-policy grpo/ppo
puyuan1996 Dec 9, 2025
2e267e3
fix(pu): fix some data bugs such as policy_versions and proximal_log_…
puyuan1996 Dec 9, 2025
788f5a7
fix(pu): fix samples compatibility
puyuan1996 Dec 9, 2025
7ba62a8
feature(pu): add init version of off-policy grpo with http-buffer
puyuan1996 Dec 9, 2025
e0ff3d3
polish(pu): add buffer_sampling_strategies for buffer sampling
puyuan1996 Dec 11, 2025
f6d901a
fix(pu): fix reward based priority sampling strategy
puyuan1996 Dec 16, 2025
47d0507
fix(pu): fix buffer add/get sample related logics
puyuan1996 Dec 16, 2025
15c5d4b
tmp
puyuan1996 Dec 16, 2025
154e786
fix(pu): fix buffer sampling strategies
puyuan1996 Dec 23, 2025
011ec2a
fix(pu): add LIFO strategy (newest-first sampling), fix exhausted_groups
Jan 6, 2026
0341509
fix(pu): fix some stability issues
puyuan1996 Jan 7, 2026
dd1ee81
fix(pu): make sure rollout ALWAYS generates new data, regardless of b…
Jan 7, 2026
3a1f4bb
polish(pu): comment some debug info
puyuan1996 Jan 7, 2026
86ddd9b
feature(pu): add train-iters-per-rollout option
puyuan1996 Jan 8, 2026
f4ada18
fix(pu): add offload_rollout before each training
puyuan1996 Jan 8, 2026
10bcab8
fix(pu): fix train-iters-per-rollout option in train.py
puyuan1996 Jan 8, 2026
beaa274
polish(pu): add generated rewards in wandb and logs
puyuan1996 Jan 15, 2026
ee1cfd2
fix(pu): fix wandb_logging
puyuan1996 Jan 15, 2026
1aa14f6
polish(pu): polish bash, add hybrid sampling strategy
Jan 15, 2026
f5e12ce
polish(pu): add enable_format_reward option
puyuan1996 Jan 15, 2026
812dd38
feature(pu): add enable_m2po_filtering and use_proximal_logp_approxim…
puyuan1996 Jan 15, 2026
a4f18c6
polish(pu): polish enable_m2po_filtering and use_proximal_logp_approx…
puyuan1996 Jan 16, 2026
d13946e
fix(pu): fix m2po_filtering
puyuan1996 Jan 22, 2026
6aac241
fix(pu): fix wandb logging
puyuan1996 Jan 22, 2026
f83e4a5
fix(pu): fix m2po_filtering
puyuan1996 Jan 22, 2026
9a99f5f
fix(pu): fix wandb_utils
puyuan1996 Jan 23, 2026
0398d0b
fix(pu): fix proximal_logprob=old_logprob bug
puyuan1996 Jan 23, 2026
64b7a8b
fix(pu): fix wandb logging
puyuan1996 Jan 23, 2026
8182af2
fix(pu): fix train_batch/raw_reward and rollout/reward in wandb
puyuan1996 Jan 23, 2026
1de3b2f
test(pu): test train_iters_per_rollout>2
puyuan1996 Jan 27, 2026
7c200ea
fix(pu): fix update_weights when train_iters_per_rollout>1
puyuan1996 Jan 27, 2026
982609f
feature(pu): add baseline bash (original ppo loss with buffer)
puyuan1996 Jan 29, 2026
dc2c9f8
fix(pu): fix buffer enable bug when setting as ppo baseline with buffer
puyuan1996 Jan 30, 2026
6fa6f92
fix(pu): fix normalization bug in m2po mask, add behav-imp-weight-cap
puyuan1996 Feb 5, 2026
8043696
test(pu): test stale8
puyuan1996 Feb 5, 2026
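
The buffer-sampling commits above name several strategies: FIFO, LIFO (newest-first), reward-based priority, and a later hybrid mode. As an illustrative sketch only, reconstructed from the commit titles rather than the PR's code, the dispatch could look like this; the function name, the `(reward, sample)` tuple shape, and the strategy names as spelled here are assumptions:

```python
import random

def sample_from_buffer(buffer, batch_size, strategy="fifo"):
    """Illustrative only: `buffer` is a list of (reward, sample) pairs, oldest
    first. Strategy names come from the commit titles above; the actual
    implementation in this PR may differ."""
    if strategy == "fifo":
        return buffer[:batch_size]  # oldest-first
    if strategy == "lifo":
        return buffer[-batch_size:][::-1]  # newest-first sampling
    if strategy == "priority":
        # Reward-based priority: sample with replacement, weighted by reward.
        weights = [max(reward, 1e-6) for reward, _ in buffer]
        return random.choices(buffer, weights=weights, k=batch_size)
    if strategy == "hybrid":
        # Hybrid: half newest-first, half reward-weighted.
        half = batch_size // 2
        newest = buffer[len(buffer) - half:][::-1] if half else []
        return newest + sample_from_buffer(buffer, batch_size - half, "priority")
    raise ValueError(f"unknown strategy: {strategy}")
```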
57 changes: 51 additions & 6 deletions examples/search-r1/generate_with_search.py
@@ -250,16 +250,61 @@ async def reward_func(args, sample, **kwargs):
"""The reward function for retrieval-based question answering.

Args:
args: the arguments
args: the arguments (now includes format reward settings)
sample: the sample to evaluate

Returns:
Reward score based on answer correctness and optionally format quality
"""
if not isinstance(sample, Sample):
raise TypeError("Sample must be an instance of Sample class.")

score = compute_score_em(
solution_str=sample.prompt + sample.response,
ground_truth=sample.label["ground_truth"],
format_score=SEARCH_R1_CONFIGS["format_score"],
)
# Check if format reward is enabled via command-line args
enable_format_reward = getattr(args, 'enable_format_reward', False)

# === Compute format validation details ===
from qa_em_format import is_valid_sequence, extract_solution, em_check, is_retrieval_correct

solution_str = sample.prompt + sample.response
ground_truth = sample.label["ground_truth"]

# Validate format and extract components
is_valid_format, format_reason = is_valid_sequence(solution_str)
answer = extract_solution(solution_str)
retrieval_correct = is_retrieval_correct(solution_str, ground_truth["target"]) if is_valid_format else False
answer_correct = em_check(answer, ground_truth["target"]) if answer else False

# Store format validation details in metadata for monitoring
if enable_format_reward:
sample.metadata['format_validation'] = {
'is_valid_format': is_valid_format,
'format_reason': format_reason,
'has_answer': answer is not None,
'answer_correct': answer_correct,
'retrieval_correct': retrieval_correct,
}

# === Compute reward score ===
if enable_format_reward:
# Use fine-grained format reward from command-line args
score = compute_score_em(
solution_str=solution_str,
ground_truth=ground_truth,
structure_format_score=getattr(args, 'structure_format_score', 0.2),
retrieval_score=getattr(args, 'retrieval_score', 0.1),
final_format_score=getattr(args, 'final_format_score', 0.1),
score=1.0,
)
else:
# Backward compatible: use default behavior (no format reward)
# This matches the original buggy behavior for comparison
score = compute_score_em(
solution_str=solution_str,
ground_truth=ground_truth,
structure_format_score=0, # No format reward
retrieval_score=0, # No retrieval reward
final_format_score=0, # No attempt reward
score=1.0,
)

return score
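
For context, the flags read via `getattr(args, ...)` above would need to be registered on the generator's argument parser. A minimal sketch, assuming argparse-style flags whose names and defaults simply mirror the `getattr` fallbacks in `reward_func`; these are not confirmed CLI options of this PR:

```python
import argparse

def add_format_reward_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Names and defaults mirror the getattr() fallbacks above; assumed, not confirmed.
    parser.add_argument("--enable-format-reward", action="store_true",
                        help="Enable fine-grained format rewards in reward_func.")
    parser.add_argument("--structure-format-score", type=float, default=0.2,
                        help="Partial credit for a well-formed response structure.")
    parser.add_argument("--retrieval-score", type=float, default=0.1,
                        help="Partial credit when retrieved passages contain the answer.")
    parser.add_argument("--final-format-score", type=float, default=0.1,
                        help="Partial credit for producing a final answer block.")
    return parser
```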
68 changes: 22 additions & 46 deletions examples/search-r1/local_dense_retriever/retrieval_server.py
@@ -19,6 +19,7 @@
import json
import warnings
from typing import Optional
import threading  # 1. [Change] Import the threading module

import datasets
import faiss
@@ -76,7 +77,6 @@ def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16)

@torch.no_grad()
def encode(self, query_list: list[str], is_query=True) -> np.ndarray:
# processing query for different encoders
if isinstance(query_list, str):
query_list = [query_list]

@@ -98,7 +98,6 @@ def encode(self, query_list: list[str], is_query=True) -> np.ndarray:
inputs = {k: v.cuda() for k, v in inputs.items()}

if "T5" in type(self.model).__name__:
# T5-based retrieval model
decoder_input_ids = torch.zeros((inputs["input_ids"].shape[0], 1), dtype=torch.long).to(
inputs["input_ids"].device
)
@@ -144,6 +143,7 @@ def batch_search(self, query_list: list[str], num: int = None, return_score: boo


class BM25Retriever(BaseRetriever):
# BM25Retriever does not touch the GPU, so no changes are needed here
def __init__(self, config):
super().__init__(config)
from pyserini.search.lucene import LuceneSearcher
@@ -202,6 +202,8 @@ def _batch_search(self, query_list: list[str], num: int = None, return_score: bo
else:
return results

# 2. [Change] Create a global thread lock to guard access to GPU resources
gpu_lock = threading.Lock()

class DenseRetriever(BaseRetriever):
def __init__(self, config):
@@ -227,8 +229,12 @@ def __init__(self, config):
def _search(self, query: str, num: int = None, return_score: bool = False):
if num is None:
num = self.topk
query_emb = self.encoder.encode(query)
scores, idxs = self.index.search(query_emb, k=num)

# 3. [Change] Wrap the GPU operations in a with block so the lock is acquired and released automatically
with gpu_lock:
query_emb = self.encoder.encode(query)
scores, idxs = self.index.search(query_emb, k=num)

idxs = idxs[0]
scores = scores[0]
results = load_docs(self.corpus, idxs)
@@ -247,15 +253,18 @@ def _batch_search(self, query_list: list[str], num: int = None, return_score: bo
scores = []
for start_idx in tqdm(range(0, len(query_list), self.batch_size), desc="Retrieval process: "):
query_batch = query_list[start_idx : start_idx + self.batch_size]
batch_emb = self.encoder.encode(query_batch)
batch_scores, batch_idxs = self.index.search(batch_emb, k=num)

# 4. [Change] Inside the loop, hold the lock only for the GPU-intensive operations
with gpu_lock:
batch_emb = self.encoder.encode(query_batch)
batch_scores, batch_idxs = self.index.search(batch_emb, k=num)

# Process the data after the lock is released, to minimize lock hold time
batch_scores = batch_scores.tolist()
batch_idxs = batch_idxs.tolist()

# load_docs is not vectorized; it operates on a flat Python list of indices
flat_idxs = sum(batch_idxs, [])
batch_results = load_docs(self.corpus, flat_idxs)
# chunk the flat results back into per-query lists
batch_results = [batch_results[i * num : (i + 1) * num] for i in range(len(batch_idxs))]

results.extend(batch_results)
@@ -283,11 +292,6 @@ def get_retriever(config):


class Config:
"""
Minimal config class (simulating your argparse)
Replace this with your real arguments or load them dynamically.
"""

def __init__(
self,
retrieval_method: str = "bm25",
@@ -328,33 +332,9 @@ class QueryRequest(BaseModel):

@app.post("/retrieve")
def retrieve_endpoint(request: QueryRequest):
"""
Endpoint that accepts queries and performs retrieval.

Input format:
{
"queries": ["What is Python?", "Tell me about neural networks."],
"topk": 3,
"return_scores": true
}

Output format (when return_scores=True, similarity scores are returned):
{
"result": [
[ # Results for each query
{"document": doc, "score": score},
# ... more documents
],
# ... results for other queries
]
}
"""
if not request.topk:
request.topk = config.retrieval_topk # fallback to default
request.topk = config.retrieval_topk

# Perform batch retrieval
tmp = retriever.batch_search(query_list=request.queries, num=request.topk, return_score=request.return_scores)

scores = []
@@ -363,11 +343,9 @@ def retrieve_endpoint(request: QueryRequest):
except:
results = tmp

# Format response
resp = []
for i, single_result in enumerate(results):
if scores:
# If scores are returned, combine them with results
combined = []
for doc, score in zip(single_result, scores[i], strict=True):
combined.append({"document": doc, "score": score})
@@ -400,10 +378,8 @@ def retrieve_endpoint(request: QueryRequest):

args = parser.parse_args()

# 1) Build a config (could also parse from arguments).
# In real usage, you'd parse your CLI arguments or environment variables.
config = Config(
retrieval_method=args.retriever_name, # or "dense"
retrieval_method=args.retriever_name,
index_path=args.index_path,
corpus_path=args.corpus_path,
retrieval_topk=args.topk,
@@ -415,8 +391,8 @@ def retrieve_endpoint(request: QueryRequest):
retrieval_batch_size=512,
)

# 2) Instantiate a global retriever so it is loaded once and reused.
retriever = get_retriever(config)

# 3) Launch the server. By default, it listens on http://127.0.0.1:8000
uvicorn.run(app, host="0.0.0.0", port=8000)
# 5. [Change] The code is now thread-safe, so the `workers=1` restriction can be safely removed,
# allowing uvicorn to use its default multiple worker threads for better concurrency.
uvicorn.run(app, host="0.0.0.0", port=8000)
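
With the per-batch locking in place, a quick client-side smoke test can exercise the endpoint from several threads at once. This is a hypothetical check, not part of the PR; the URL, payload, and response shape follow the `QueryRequest` model and the removed docstring above:

```python
from concurrent.futures import ThreadPoolExecutor

import requests

URL = "http://127.0.0.1:8000/retrieve"

def query(q: str):
    payload = {"queries": [q], "topk": 3, "return_scores": True}
    resp = requests.post(URL, json=payload, timeout=60)
    resp.raise_for_status()
    return resp.json()["result"]

if __name__ == "__main__":
    questions = [f"test question {i}" for i in range(16)]
    # Concurrent requests: gpu_lock serializes encode/index.search on the server,
    # so results should match what sequential calls would return.
    with ThreadPoolExecutor(max_workers=8) as pool:
        for result in pool.map(query, questions):
            print(len(result[0]), "documents returned")
```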