Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions tests/utils/test_chat_template_on_cpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright 2026 Bytedance Ltd. and/or its affiliates

import pytest

from verl.utils.chat_template import extract_system_prompt_and_generation, initialize_system_prompt


class AppendOnlyTokenizer:
prefix = [1, 2]
user_turn = [10, 11]
assistant_turn = [20]
generation_prompt = [30]

def apply_chat_template(self, messages, *, add_generation_prompt, tokenize, **kwargs):
assert tokenize
token_ids = list(self.prefix)
for message in messages:
if message["role"] == "user":
token_ids.extend(self.user_turn)
elif message["role"] == "assistant":
token_ids.extend(self.assistant_turn)
else:
raise ValueError(f"Unsupported role: {message['role']}")

if add_generation_prompt:
token_ids.extend(self.generation_prompt)
return token_ids


class AlternatingTokenizer(AppendOnlyTokenizer):
prefix = [101]
user_turn = [110]
assistant_turn = [120]

def apply_chat_template(self, messages, *, add_generation_prompt, tokenize, **kwargs):
for prev_message, message in zip(messages, messages[1:], strict=False):
if prev_message["role"] == message["role"]:
raise ValueError("Conversation roles must alternate")
return super().apply_chat_template(
messages,
add_generation_prompt=add_generation_prompt,
tokenize=tokenize,
**kwargs,
)


class AppendOnlyWithFinalTokenTokenizer(AppendOnlyTokenizer):
final_token = [99]

def apply_chat_template(self, messages, *, add_generation_prompt, tokenize, **kwargs):
token_ids = super().apply_chat_template(
messages,
add_generation_prompt=False,
tokenize=tokenize,
**kwargs,
)
if add_generation_prompt:
token_ids.extend(self.generation_prompt)
else:
token_ids.extend(self.final_token)
return token_ids


class ConversationFinalTokenTokenizer:
user_turn = [210, 211]
assistant_turn = [220, 221]
final_token = [299]
generation_prompt = [230]

def apply_chat_template(self, messages, *, add_generation_prompt, tokenize, **kwargs):
assert tokenize
token_ids = []
for message in messages:
if message["role"] == "user":
token_ids.extend(self.user_turn)
elif message["role"] == "assistant":
token_ids.extend(self.assistant_turn)
else:
raise ValueError(f"Unsupported role: {message['role']}")

if add_generation_prompt:
token_ids.extend(self.generation_prompt)
else:
token_ids.extend(self.final_token)
return token_ids


def test_initialize_system_prompt_infers_append_only_prefix():
assert initialize_system_prompt(AppendOnlyTokenizer()) == AppendOnlyTokenizer.prefix


def test_extract_system_prompt_and_generation_uses_append_only_prefix():
system_prompt, generation_prompt = extract_system_prompt_and_generation(AppendOnlyTokenizer())

assert system_prompt == AppendOnlyTokenizer.prefix
assert generation_prompt == AppendOnlyTokenizer.generation_prompt


def test_initialize_system_prompt_supports_alternating_role_templates():
assert initialize_system_prompt(AlternatingTokenizer()) == AlternatingTokenizer.prefix


def test_initialize_system_prompt_handles_common_final_tokens():
assert initialize_system_prompt(AppendOnlyWithFinalTokenTokenizer()) == AppendOnlyWithFinalTokenTokenizer.prefix


@pytest.mark.parametrize("tokenizer_cls", [AppendOnlyWithFinalTokenTokenizer, ConversationFinalTokenTokenizer])
def test_extract_generation_prompt_handles_replaced_final_tokens(tokenizer_cls):
_, generation_prompt = extract_system_prompt_and_generation(tokenizer_cls())

assert generation_prompt == tokenizer_cls.generation_prompt


def test_extract_system_prompt_and_generation_supports_alternating_role_templates():
system_prompt, generation_prompt = extract_system_prompt_and_generation(AlternatingTokenizer())

assert system_prompt == AlternatingTokenizer.prefix
assert generation_prompt == AlternatingTokenizer.generation_prompt


@pytest.mark.parametrize("helper", [initialize_system_prompt, extract_system_prompt_and_generation])
def test_system_prompt_inference_ignores_non_append_only_templates(helper):
result = helper(ConversationFinalTokenTokenizer())
system_prompt = result[0] if isinstance(result, tuple) else result

assert system_prompt == []
162 changes: 129 additions & 33 deletions verl/utils/chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,117 @@
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


def _apply_chat_template_token_ids(
tokenizer, messages: list[dict], *, add_generation_prompt: bool, **kwargs
) -> list[int]:
return normalize_token_ids(
tokenizer.apply_chat_template(
messages,
add_generation_prompt=add_generation_prompt,
tokenize=True,
**kwargs,
)
)


def _common_suffix_len(*token_lists: list[int]) -> int:
suffix_len = 0
for tokens in zip(*(reversed(token_ids) for token_ids in token_lists), strict=False):
if any(token != tokens[0] for token in tokens[1:]):
break
suffix_len += 1
return suffix_len
Comment thread
anzhsoft marked this conversation as resolved.


def _common_prefix_len(first: list[int], second: list[int]) -> int:
prefix_len = 0
for first_token, second_token in zip(first, second, strict=False):
if first_token != second_token:
break
prefix_len += 1
return prefix_len
Comment thread
anzhsoft marked this conversation as resolved.


def _remove_suffix(token_ids: list[int], suffix_len: int) -> list[int]:
if suffix_len == 0:
return token_ids
return token_ids[:-suffix_len]


def _infer_prefix_from_appended_turn_core(
first_turn: list[int], base: list[int], extended: list[int]
) -> list[int] | None:
if len(extended) <= len(base) or extended[: len(base)] != base:
return None

appended_turn = extended[len(base) :]
if len(appended_turn) > len(first_turn) or first_turn[-len(appended_turn) :] != appended_turn:
return None

return first_turn[: -len(appended_turn)]


def _infer_prefix_from_appended_turn(first_turn: list[int], base: list[int], extended: list[int]) -> list[int] | None:
suffix_lens = [0, *range(_common_suffix_len(first_turn, base, extended), 0, -1)]
for suffix_len in suffix_lens:
system_prompt = _infer_prefix_from_appended_turn_core(
_remove_suffix(first_turn, suffix_len),
_remove_suffix(base, suffix_len),
_remove_suffix(extended, suffix_len),
)
if system_prompt is not None:
return system_prompt

return None


def _extract_generation_prompt(no_generation: list[int], with_generation: list[int]) -> list[int]:
prefix_len = _common_prefix_len(no_generation, with_generation)
return with_generation[prefix_len:]


def _infer_system_prompt(tokenizer, token1: list[int], **apply_chat_template_kwargs) -> list[int]:
two_users = [{"role": "user", "content": ""}, {"role": "user", "content": ""}]
user_assistant = [{"role": "user", "content": ""}, {"role": "assistant", "content": ""}]
user_assistant_user = user_assistant + [{"role": "user", "content": ""}]

# Prefer the historical consecutive-user probe when the template supports it.
try:
token2 = _apply_chat_template_token_ids(
tokenizer,
two_users,
add_generation_prompt=False,
**apply_chat_template_kwargs,
)
system_prompt = _infer_prefix_from_appended_turn(token1, token1, token2)
if system_prompt is not None:
return system_prompt
except Exception:
logger.debug("Failed to render consecutive user messages for system prompt inference.", exc_info=True)

# Some official templates require alternating user/assistant roles.
try:
token2 = _apply_chat_template_token_ids(
tokenizer,
user_assistant,
add_generation_prompt=False,
**apply_chat_template_kwargs,
)
token3 = _apply_chat_template_token_ids(
tokenizer,
user_assistant_user,
add_generation_prompt=False,
**apply_chat_template_kwargs,
)
system_prompt = _infer_prefix_from_appended_turn(token1, token2, token3)
if system_prompt is not None:
return system_prompt
except Exception:
logger.debug("Failed to render alternating messages for system prompt inference.", exc_info=True)

return []


def initialize_system_prompt(tokenizer, **apply_chat_template_kwargs) -> list[int]:
"""
Initialize system prompt tokens for chat templates that support them.
Expand All @@ -21,47 +132,32 @@ def initialize_system_prompt(tokenizer, **apply_chat_template_kwargs) -> list[in
Returns:
List of token IDs for the system prompt, or empty list if not supported
"""
token1 = normalize_token_ids(
tokenizer.apply_chat_template(
[{"role": "user", "content": ""}], add_generation_prompt=False, tokenize=True, **apply_chat_template_kwargs
)
token1 = _apply_chat_template_token_ids(
tokenizer,
[{"role": "user", "content": ""}],
add_generation_prompt=False,
**apply_chat_template_kwargs,
)
token2 = normalize_token_ids(
tokenizer.apply_chat_template(
[{"role": "user", "content": ""}] * 2,
add_generation_prompt=False,
tokenize=True,
**apply_chat_template_kwargs,
)
)
# get system prompt tokens
system_prompt = token1[: -(len(token2) - len(token1))]
return system_prompt
return _infer_system_prompt(tokenizer, token1, **apply_chat_template_kwargs)


def extract_system_prompt_and_generation(tokenizer, **apply_chat_template_kwargs):
token1 = normalize_token_ids(
tokenizer.apply_chat_template(
[{"role": "user", "content": ""}], add_generation_prompt=False, tokenize=True, **apply_chat_template_kwargs
)
)
token2 = normalize_token_ids(
tokenizer.apply_chat_template(
[{"role": "user", "content": ""}] * 2,
add_generation_prompt=False,
tokenize=True,
**apply_chat_template_kwargs,
)
token1 = _apply_chat_template_token_ids(
tokenizer,
[{"role": "user", "content": ""}],
add_generation_prompt=False,
**apply_chat_template_kwargs,
)
# get system prompt tokens
system_prompt = token1[: -(len(token2) - len(token1))]
system_prompt = _infer_system_prompt(tokenizer, token1, **apply_chat_template_kwargs)
# get generate prompt tokens
token3 = normalize_token_ids(
tokenizer.apply_chat_template(
[{"role": "user", "content": ""}], add_generation_prompt=True, tokenize=True, **apply_chat_template_kwargs
)
token3 = _apply_chat_template_token_ids(
tokenizer,
[{"role": "user", "content": ""}],
add_generation_prompt=True,
**apply_chat_template_kwargs,
)
generate_prompt = token3[len(token1) :]
generate_prompt = _extract_generation_prompt(token1, token3)

return system_prompt, generate_prompt

Expand Down