diff --git a/tests/utils/test_chat_template_on_cpu.py b/tests/utils/test_chat_template_on_cpu.py new file mode 100644 index 00000000000..66434ef7d86 --- /dev/null +++ b/tests/utils/test_chat_template_on_cpu.py @@ -0,0 +1,126 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates + +import pytest + +from verl.utils.chat_template import extract_system_prompt_and_generation, initialize_system_prompt + + +class AppendOnlyTokenizer: + prefix = [1, 2] + user_turn = [10, 11] + assistant_turn = [20] + generation_prompt = [30] + + def apply_chat_template(self, messages, *, add_generation_prompt, tokenize, **kwargs): + assert tokenize + token_ids = list(self.prefix) + for message in messages: + if message["role"] == "user": + token_ids.extend(self.user_turn) + elif message["role"] == "assistant": + token_ids.extend(self.assistant_turn) + else: + raise ValueError(f"Unsupported role: {message['role']}") + + if add_generation_prompt: + token_ids.extend(self.generation_prompt) + return token_ids + + +class AlternatingTokenizer(AppendOnlyTokenizer): + prefix = [101] + user_turn = [110] + assistant_turn = [120] + + def apply_chat_template(self, messages, *, add_generation_prompt, tokenize, **kwargs): + for prev_message, message in zip(messages, messages[1:], strict=False): + if prev_message["role"] == message["role"]: + raise ValueError("Conversation roles must alternate") + return super().apply_chat_template( + messages, + add_generation_prompt=add_generation_prompt, + tokenize=tokenize, + **kwargs, + ) + + +class AppendOnlyWithFinalTokenTokenizer(AppendOnlyTokenizer): + final_token = [99] + + def apply_chat_template(self, messages, *, add_generation_prompt, tokenize, **kwargs): + token_ids = super().apply_chat_template( + messages, + add_generation_prompt=False, + tokenize=tokenize, + **kwargs, + ) + if add_generation_prompt: + token_ids.extend(self.generation_prompt) + else: + token_ids.extend(self.final_token) + return token_ids + + +class ConversationFinalTokenTokenizer: + user_turn = [210, 211] + assistant_turn = [220, 221] + final_token = [299] + generation_prompt = [230] + + def apply_chat_template(self, messages, *, add_generation_prompt, tokenize, **kwargs): + assert tokenize + token_ids = [] + for message in messages: + if message["role"] == "user": + token_ids.extend(self.user_turn) + elif message["role"] == "assistant": + token_ids.extend(self.assistant_turn) + else: + raise ValueError(f"Unsupported role: {message['role']}") + + if add_generation_prompt: + token_ids.extend(self.generation_prompt) + else: + token_ids.extend(self.final_token) + return token_ids + + +def test_initialize_system_prompt_infers_append_only_prefix(): + assert initialize_system_prompt(AppendOnlyTokenizer()) == AppendOnlyTokenizer.prefix + + +def test_extract_system_prompt_and_generation_uses_append_only_prefix(): + system_prompt, generation_prompt = extract_system_prompt_and_generation(AppendOnlyTokenizer()) + + assert system_prompt == AppendOnlyTokenizer.prefix + assert generation_prompt == AppendOnlyTokenizer.generation_prompt + + +def test_initialize_system_prompt_supports_alternating_role_templates(): + assert initialize_system_prompt(AlternatingTokenizer()) == AlternatingTokenizer.prefix + + +def test_initialize_system_prompt_handles_common_final_tokens(): + assert initialize_system_prompt(AppendOnlyWithFinalTokenTokenizer()) == AppendOnlyWithFinalTokenTokenizer.prefix + + +@pytest.mark.parametrize("tokenizer_cls", [AppendOnlyWithFinalTokenTokenizer, ConversationFinalTokenTokenizer]) +def test_extract_generation_prompt_handles_replaced_final_tokens(tokenizer_cls): + _, generation_prompt = extract_system_prompt_and_generation(tokenizer_cls()) + + assert generation_prompt == tokenizer_cls.generation_prompt + + +def test_extract_system_prompt_and_generation_supports_alternating_role_templates(): + system_prompt, generation_prompt = extract_system_prompt_and_generation(AlternatingTokenizer()) + + assert system_prompt == AlternatingTokenizer.prefix + assert generation_prompt == AlternatingTokenizer.generation_prompt + + +@pytest.mark.parametrize("helper", [initialize_system_prompt, extract_system_prompt_and_generation]) +def test_system_prompt_inference_ignores_non_append_only_templates(helper): + result = helper(ConversationFinalTokenTokenizer()) + system_prompt = result[0] if isinstance(result, tuple) else result + + assert system_prompt == [] diff --git a/verl/utils/chat_template.py b/verl/utils/chat_template.py index 7e5a1de3bc5..7a2ca910b4e 100644 --- a/verl/utils/chat_template.py +++ b/verl/utils/chat_template.py @@ -10,6 +10,117 @@ logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) +def _apply_chat_template_token_ids( + tokenizer, messages: list[dict], *, add_generation_prompt: bool, **kwargs +) -> list[int]: + return normalize_token_ids( + tokenizer.apply_chat_template( + messages, + add_generation_prompt=add_generation_prompt, + tokenize=True, + **kwargs, + ) + ) + + +def _common_suffix_len(*token_lists: list[int]) -> int: + suffix_len = 0 + for tokens in zip(*(reversed(token_ids) for token_ids in token_lists), strict=False): + if any(token != tokens[0] for token in tokens[1:]): + break + suffix_len += 1 + return suffix_len + + +def _common_prefix_len(first: list[int], second: list[int]) -> int: + prefix_len = 0 + for first_token, second_token in zip(first, second, strict=False): + if first_token != second_token: + break + prefix_len += 1 + return prefix_len + + +def _remove_suffix(token_ids: list[int], suffix_len: int) -> list[int]: + if suffix_len == 0: + return token_ids + return token_ids[:-suffix_len] + + +def _infer_prefix_from_appended_turn_core( + first_turn: list[int], base: list[int], extended: list[int] +) -> list[int] | None: + if len(extended) <= len(base) or extended[: len(base)] != base: + return None + + appended_turn = extended[len(base) :] + if len(appended_turn) > len(first_turn) or first_turn[-len(appended_turn) :] != appended_turn: + return None + + return first_turn[: -len(appended_turn)] + + +def _infer_prefix_from_appended_turn(first_turn: list[int], base: list[int], extended: list[int]) -> list[int] | None: + suffix_lens = [0, *range(_common_suffix_len(first_turn, base, extended), 0, -1)] + for suffix_len in suffix_lens: + system_prompt = _infer_prefix_from_appended_turn_core( + _remove_suffix(first_turn, suffix_len), + _remove_suffix(base, suffix_len), + _remove_suffix(extended, suffix_len), + ) + if system_prompt is not None: + return system_prompt + + return None + + +def _extract_generation_prompt(no_generation: list[int], with_generation: list[int]) -> list[int]: + prefix_len = _common_prefix_len(no_generation, with_generation) + return with_generation[prefix_len:] + + +def _infer_system_prompt(tokenizer, token1: list[int], **apply_chat_template_kwargs) -> list[int]: + two_users = [{"role": "user", "content": ""}, {"role": "user", "content": ""}] + user_assistant = [{"role": "user", "content": ""}, {"role": "assistant", "content": ""}] + user_assistant_user = user_assistant + [{"role": "user", "content": ""}] + + # Prefer the historical consecutive-user probe when the template supports it. + try: + token2 = _apply_chat_template_token_ids( + tokenizer, + two_users, + add_generation_prompt=False, + **apply_chat_template_kwargs, + ) + system_prompt = _infer_prefix_from_appended_turn(token1, token1, token2) + if system_prompt is not None: + return system_prompt + except Exception: + logger.debug("Failed to render consecutive user messages for system prompt inference.", exc_info=True) + + # Some official templates require alternating user/assistant roles. + try: + token2 = _apply_chat_template_token_ids( + tokenizer, + user_assistant, + add_generation_prompt=False, + **apply_chat_template_kwargs, + ) + token3 = _apply_chat_template_token_ids( + tokenizer, + user_assistant_user, + add_generation_prompt=False, + **apply_chat_template_kwargs, + ) + system_prompt = _infer_prefix_from_appended_turn(token1, token2, token3) + if system_prompt is not None: + return system_prompt + except Exception: + logger.debug("Failed to render alternating messages for system prompt inference.", exc_info=True) + + return [] + + def initialize_system_prompt(tokenizer, **apply_chat_template_kwargs) -> list[int]: """ Initialize system prompt tokens for chat templates that support them. @@ -21,47 +132,32 @@ def initialize_system_prompt(tokenizer, **apply_chat_template_kwargs) -> list[in Returns: List of token IDs for the system prompt, or empty list if not supported """ - token1 = normalize_token_ids( - tokenizer.apply_chat_template( - [{"role": "user", "content": ""}], add_generation_prompt=False, tokenize=True, **apply_chat_template_kwargs - ) + token1 = _apply_chat_template_token_ids( + tokenizer, + [{"role": "user", "content": ""}], + add_generation_prompt=False, + **apply_chat_template_kwargs, ) - token2 = normalize_token_ids( - tokenizer.apply_chat_template( - [{"role": "user", "content": ""}] * 2, - add_generation_prompt=False, - tokenize=True, - **apply_chat_template_kwargs, - ) - ) - # get system prompt tokens - system_prompt = token1[: -(len(token2) - len(token1))] - return system_prompt + return _infer_system_prompt(tokenizer, token1, **apply_chat_template_kwargs) def extract_system_prompt_and_generation(tokenizer, **apply_chat_template_kwargs): - token1 = normalize_token_ids( - tokenizer.apply_chat_template( - [{"role": "user", "content": ""}], add_generation_prompt=False, tokenize=True, **apply_chat_template_kwargs - ) - ) - token2 = normalize_token_ids( - tokenizer.apply_chat_template( - [{"role": "user", "content": ""}] * 2, - add_generation_prompt=False, - tokenize=True, - **apply_chat_template_kwargs, - ) + token1 = _apply_chat_template_token_ids( + tokenizer, + [{"role": "user", "content": ""}], + add_generation_prompt=False, + **apply_chat_template_kwargs, ) # get system prompt tokens - system_prompt = token1[: -(len(token2) - len(token1))] + system_prompt = _infer_system_prompt(tokenizer, token1, **apply_chat_template_kwargs) # get generate prompt tokens - token3 = normalize_token_ids( - tokenizer.apply_chat_template( - [{"role": "user", "content": ""}], add_generation_prompt=True, tokenize=True, **apply_chat_template_kwargs - ) + token3 = _apply_chat_template_token_ids( + tokenizer, + [{"role": "user", "content": ""}], + add_generation_prompt=True, + **apply_chat_template_kwargs, ) - generate_prompt = token3[len(token1) :] + generate_prompt = _extract_generation_prompt(token1, token3) return system_prompt, generate_prompt