diff --git a/CHANGELOG.md b/CHANGELOG.md index 01f16052..69873548 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to - ✅(front) add tests for SourceItem component - ✨(documents) make document context hybrid - 🐛(fix) add prevent_url_hallucination instruction to ConversationAgent +- ✨(back) add output token limit per message ### Changed diff --git a/docs/env.md b/docs/env.md index 2ae11091..f478c8a5 100644 --- a/docs/env.md +++ b/docs/env.md @@ -83,6 +83,7 @@ These are the environment variables you can set for the `conversations-backend` | AI_MODEL | AI Model name to use (used in default LLM configuration, not for production use) | | | AI_AGENT_INSTRUCTIONS | Base instruction for the AI agent (used in default LLM configuration, not for production use) | You are a helpful assistant. Wrap formulas... | | AI_AGENT_TOOLS | List of enabled tools for the agent (used in default LLM configuration, not for production use) | [] | +| LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE | Maximum number of output tokens the LLM can generate per message. When reached, the response is truncated and the user is notified. 
| 8192 | | CONVERSION_API_ENDPOINT | Conversion API endpoint | convert-markdown | | CONVERSION_API_CONTENT_FIELD | Conversion api content field | content | | CONVERSION_API_TIMEOUT | Conversion api timeout | 30 | diff --git a/src/backend/chat/clients/pydantic_ai.py b/src/backend/chat/clients/pydantic_ai.py index b946fb22..689a08fa 100644 --- a/src/backend/chat/clients/pydantic_ai.py +++ b/src/backend/chat/clients/pydantic_ai.py @@ -99,6 +99,7 @@ from asgiref.sync import sync_to_async from langfuse import get_client from pydantic_ai import Agent, InstrumentationSettings, RunContext, RunUsage +from pydantic_ai.capabilities import Hooks from pydantic_ai.messages import ( BinaryContent, DocumentUrl, @@ -123,6 +124,7 @@ UserPromptPart, ) from pydantic_ai.models import Model, infer_model_profile +from pydantic_ai.settings import ModelSettings from core.feature_flags.helpers import is_feature_enabled @@ -302,6 +304,15 @@ def __init__( # pylint: disable=too-many-arguments,too-many-positional-argument self._is_smart_search_enabled = user.allow_smart_web_search self._fake_streaming_delay = settings.FAKE_STREAMING_DELAY + self._last_finish_reason: str | None = None + + self._truncation_hooks = Hooks() + + @self._truncation_hooks.on.after_model_request + async def _detect_truncation(_ctx, *, request_context, response): # pylint: disable=unused-argument + self._last_finish_reason = response.finish_reason + return response + self._context_deps = ContextDeps( conversation=conversation, user=user, @@ -321,6 +332,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-positional-argument if self._langfuse_available else False, deps_type=ContextDeps, + capabilities=[self._truncation_hooks], ) add_document_rag_search_tool_from_setting(self.conversation_agent, self.user) @@ -419,6 +431,7 @@ async def _clean(self): It can be used to release resources or perform any necessary cleanup. 
""" self._last_stop_check = 0 + self._last_finish_reason = None await cache.adelete(self._stop_cache_key) # --------------------------------------------------------------------- # @@ -1122,6 +1135,10 @@ async def _finalize_conversation( # pylint: disable=too-many-arguments,too-many ] ) self._update_langfuse_trace(run_output) + + if self._last_finish_reason == "length": + yield events_v4.MessageAnnotationPart(annotations=[{"truncated": True}]) + # Vercel finish message yield events_v4.FinishMessagePart( finish_reason=events_v4.FinishReason.STOP, @@ -1196,11 +1213,14 @@ async def _run_agent( # pylint: disable=too-many-locals if history[-1].parts and history[-1].parts[-1].part_kind == "tool-return": history.append(ModelResponse(parts=[TextPart(content="ok")], kind="response")) + model_settings = ModelSettings(max_tokens=settings.LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE) + async with self.conversation_agent.iter( [user_prompt] + input_images, message_history=history, # history will pass through agent's history_processors deps=self._context_deps, toolsets=mcp_servers, + model_settings=model_settings, ) as run: state = StreamingState() async for event in self._process_agent_nodes(run, state, langfuse): @@ -1278,6 +1298,11 @@ def _prepare_update_conversation( _output_ui_message.annotations = [] _output_ui_message.annotations.append({"co2_impact": co2_impact}) + if self._last_finish_reason == "length": + if _output_ui_message.annotations is None: + _output_ui_message.annotations = [] + _output_ui_message.annotations.append({"truncated": True}) + usage["co2_impact"] += self.conversation.agent_usage.get("co2_impact", 0) usage["promptTokens"] += self.conversation.agent_usage.get("promptTokens", 0) usage["completionTokens"] += self.conversation.agent_usage.get("completionTokens", 0) diff --git a/src/backend/chat/tests/clients/pydantic_ai/test_output_token_limit.py b/src/backend/chat/tests/clients/pydantic_ai/test_output_token_limit.py new file mode 100644 index 00000000..f0581531 --- 
/dev/null +++ b/src/backend/chat/tests/clients/pydantic_ai/test_output_token_limit.py @@ -0,0 +1,201 @@ +"""Tests for the output token limit feature.""" + +from unittest.mock import MagicMock, patch + +from django.conf import settings as django_settings + +import pytest +from asgiref.sync import sync_to_async +from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart +from pydantic_ai.models.function import AgentInfo, FunctionModel +from pydantic_ai.settings import ModelSettings + +from chat.ai_sdk_types import TextUIPart, UIMessage +from chat.clients.pydantic_ai import AIAgentService +from chat.factories import ChatConversationFactory + +pytestmark = pytest.mark.django_db() + + +@pytest.fixture(autouse=True) +def base_settings(settings): + """Set up base settings for all tests in this module.""" + settings.AI_BASE_URL = "https://api.llm.com/v1/" + settings.AI_API_KEY = "test-key" + settings.AI_MODEL = "model-123" + settings.AI_AGENT_INSTRUCTIONS = "You are a helpful assistant" + settings.AI_AGENT_TOOLS = [] + settings.LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE = 1000 + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(name="ui_messages") +def ui_messages_fixture(): + """Fixture providing a single user UIMessage.""" + return [ + UIMessage( + id="msg-1", + role="user", + content="Hello", + parts=[TextUIPart(type="text", text="Hello")], + ) + ] + + +@pytest.fixture(name="simple_model") +def simple_model_fixture(): + """Fixture providing a simple streaming FunctionModel.""" + + async def _model(_messages: list[ModelMessage], _info: AgentInfo): + yield "Hello world" + + return FunctionModel(stream_function=_model) + + +# --------------------------------------------------------------------------- +# Settings tests +# --------------------------------------------------------------------------- + + +def 
test_llm_max_output_tokens_per_message_setting_exists(): + """Setting must exist with a positive integer value.""" + assert hasattr(django_settings, "LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE") + assert isinstance(django_settings.LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE, int) + assert django_settings.LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE == 1000 + + +# --------------------------------------------------------------------------- +# Hook / flag tests +# --------------------------------------------------------------------------- + + +def test_response_truncated_flag_initialized_to_false(): + """AIAgentService must start with _last_finish_reason=None.""" + conversation = ChatConversationFactory() + service = AIAgentService(conversation, user=conversation.owner) + assert service._last_finish_reason is None # pylint: disable=protected-access + + +@pytest.mark.asyncio +async def test_response_truncated_flag_reset_by_clean(): + """_clean() must reset _last_finish_reason to None.""" + conversation = await sync_to_async(ChatConversationFactory)() + service = AIAgentService(conversation, user=conversation.owner) + service._last_finish_reason = "length" # pylint: disable=protected-access + await service._clean() # pylint: disable=protected-access + assert service._last_finish_reason is None # pylint: disable=protected-access + + +@pytest.mark.asyncio +async def test_hook_sets_truncated_flag_when_finish_reason_is_length(): + """after_model_request hook must set _last_finish_reason='length' for finish_reason='length'.""" + conversation = await sync_to_async(ChatConversationFactory)() + service = AIAgentService(conversation, user=conversation.owner) + + mock_response = ModelResponse( + parts=[TextPart(content="Hello wo...")], + finish_reason="length", + ) + mock_ctx = MagicMock() + mock_request_context = MagicMock() + + result = await service.conversation_agent._root_capability.after_model_request( # pylint: disable=protected-access + mock_ctx, request_context=mock_request_context, 
response=mock_response + ) + + assert service._last_finish_reason == "length" # pylint: disable=protected-access + assert result == mock_response + + +@pytest.mark.asyncio +async def test_hook_does_not_set_truncated_flag_when_finish_reason_is_stop(): + """after_model_request hook must set _last_finish_reason='stop' for finish_reason='stop'.""" + conversation = await sync_to_async(ChatConversationFactory)() + service = AIAgentService(conversation, user=conversation.owner) + + mock_response = ModelResponse( + parts=[TextPart(content="Hello world")], + finish_reason="stop", + ) + mock_ctx = MagicMock() + mock_request_context = MagicMock() + + await service.conversation_agent._root_capability.after_model_request( # pylint: disable=protected-access + mock_ctx, request_context=mock_request_context, response=mock_response + ) + + assert service._last_finish_reason == "stop" # pylint: disable=protected-access + + +# --------------------------------------------------------------------------- +# Stream event tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_run_agent_passes_max_tokens_model_settings(ui_messages): + """_run_agent must pass + ModelSettings(max_tokens=LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE) to agent.iter().""" + conversation = await sync_to_async(ChatConversationFactory)() + service = AIAgentService(conversation, user=conversation.owner) + + captured_kwargs = {} + original_iter = service.conversation_agent.iter + + def capturing_iter(*args, **kwargs): + captured_kwargs.update(kwargs) + return original_iter(*args, **kwargs) + + async def simple_model(_messages: list[ModelMessage], _info: AgentInfo): + yield "ok" + + with service.conversation_agent.override(model=FunctionModel(stream_function=simple_model)): + with patch.object(service.conversation_agent, "iter", side_effect=capturing_iter): + async for _ in service.stream_data_async(ui_messages): + pass + + assert "model_settings" in 
captured_kwargs + assert captured_kwargs["model_settings"] == ModelSettings(max_tokens=1000) + + +@pytest.mark.asyncio +async def test_truncation_annotation_emitted_when_flag_is_set(ui_messages): + """When the model returns finish_reason='length', stream must emit + MessageAnnotationPart with truncated=True.""" + conversation = await sync_to_async(ChatConversationFactory)() + service = AIAgentService(conversation, user=conversation.owner) + service._support_streaming = False # pylint: disable=protected-access + + def _length_model(_messages: list[ModelMessage], _info: AgentInfo): + return ModelResponse(parts=[TextPart(content="Hello world")], finish_reason="length") + + length_model = FunctionModel(function=_length_model) + + chunks = [] + with service.conversation_agent.override(model=length_model): + async for chunk in service.stream_data_async(ui_messages): + chunks.append(chunk) + + # The MessageAnnotationPart is encoded as a line containing '"truncated"' + annotation_chunks = [c for c in chunks if '"truncated"' in c] + assert len(annotation_chunks) == 1 + assert '"truncated"' in annotation_chunks[0] + + +@pytest.mark.asyncio +async def test_truncation_annotation_not_emitted_when_flag_is_false(ui_messages, simple_model): + """When _last_finish_reason is not 'length', stream must NOT emit a truncation annotation.""" + conversation = await sync_to_async(ChatConversationFactory)() + service = AIAgentService(conversation, user=conversation.owner) + + chunks = [] + with service.conversation_agent.override(model=simple_model): + async for chunk in service.stream_data_async(ui_messages): + chunks.append(chunk) + + annotation_chunks = [c for c in chunks if '"truncated"' in c] + assert len(annotation_chunks) == 0 diff --git a/src/backend/chat/tests/test_ai_agent_service_co2.py b/src/backend/chat/tests/test_ai_agent_service_co2.py index 80378658..902556b2 100644 --- a/src/backend/chat/tests/test_ai_agent_service_co2.py +++ 
b/src/backend/chat/tests/test_ai_agent_service_co2.py @@ -35,6 +35,7 @@ def service_fixture(conversation): """Instantiate AIAgentService without __init__, injecting only the conversation.""" s = object.__new__(AIAgentService) s.conversation = conversation + s._last_finish_reason = None return s diff --git a/src/backend/conversations/settings.py b/src/backend/conversations/settings.py index 1b5c0d3e..fd37c4ba 100755 --- a/src/backend/conversations/settings.py +++ b/src/backend/conversations/settings.py @@ -678,6 +678,11 @@ class Base(BraveSettings, Configuration): environ_name="DEFAULT_ALLOW_CONVERSATION_ANALYTICS", environ_prefix=None, ) + LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE = values.IntegerValue( + 8192, + environ_name="LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE", + environ_prefix=None, + ) DEFAULT_ALLOW_SMART_WEB_SEARCH = values.BooleanValue( default=False, environ_name="DEFAULT_ALLOW_SMART_WEB_SEARCH", diff --git a/src/frontend/apps/conversations/src/features/chat/components/MessageItem.tsx b/src/frontend/apps/conversations/src/features/chat/components/MessageItem.tsx index 0bd1f5ad..7ea130f1 100644 --- a/src/frontend/apps/conversations/src/features/chat/components/MessageItem.tsx +++ b/src/frontend/apps/conversations/src/features/chat/components/MessageItem.tsx @@ -12,6 +12,7 @@ import { } from '@/features/chat/components/MessageBlock'; import { SourceItemList } from '@/features/chat/components/SourceItemList'; import { ToolInvocationItem } from '@/features/chat/components/ToolInvocationItem'; +import { TruncatedResponseMessage } from '@/features/chat/components/TruncatedResponseMessage'; // Memoized blocks list to prevent parent re-renders from causing block remounts const BlocksList = React.memo( @@ -242,6 +243,19 @@ const MessageItemComponent: React.FC = ({ ); }, [toolInvocationParts]); + const isTruncated = React.useMemo(() => { + return ( + message.role === 'assistant' && + Array.isArray(message.annotations) && + message.annotations.some( + (a) => + typeof a === 
'object' && + a !== null && + (a as Record<string, unknown>).truncated === true, + ) + ); + }, [message.annotations, message.role]); + const activeToolInvocation = React.useMemo(() => { const tool = toolInvocationParts.find( (part) => part.toolInvocation.toolName !== 'document_parsing', @@ -478,6 +492,7 @@ const MessageItemComponent: React.FC = ({ )} + {isTruncated && <TruncatedResponseMessage />} @@ -510,6 +525,23 @@ const arePropsEqual = ( return false; } + // Check annotations changes (for truncation and other metadata) + const prevAnnotationsLength = prevProps.message.annotations?.length ?? 0; + const nextAnnotationsLength = nextProps.message.annotations?.length ?? 0; + if (prevAnnotationsLength !== nextAnnotationsLength) { + return false; + } + const isTruncated = (annotations: typeof prevProps.message.annotations) => + annotations?.some( + (a) => + typeof a === 'object' && + a !== null && + (a as Record<string, unknown>).truncated === true, + ) ?? false; + if (isTruncated(prevProps.message.annotations) !== isTruncated(nextProps.message.annotations)) { + return false; + } + // Check attachments const prevAttachmentsLength = prevProps.message.experimental_attachments?.length ?? 0; diff --git a/src/frontend/apps/conversations/src/features/chat/components/TruncatedResponseMessage.tsx b/src/frontend/apps/conversations/src/features/chat/components/TruncatedResponseMessage.tsx new file mode 100644 index 00000000..9752455b --- /dev/null +++ b/src/frontend/apps/conversations/src/features/chat/components/TruncatedResponseMessage.tsx @@ -0,0 +1,24 @@ +import { useTranslation } from 'react-i18next'; + +import { Box, Text } from '@/components'; + +export const TruncatedResponseMessage = () => { + const { t } = useTranslation(); + + return ( + + + {t( + 'The response was cut off because it reached the maximum length. 
Try rephrasing your question to get a shorter answer.', + )} + + + ); +}; diff --git a/src/frontend/apps/conversations/src/features/chat/components/__tests__/MessageItem.test.tsx b/src/frontend/apps/conversations/src/features/chat/components/__tests__/MessageItem.test.tsx index 959493de..5ea773f9 100644 --- a/src/frontend/apps/conversations/src/features/chat/components/__tests__/MessageItem.test.tsx +++ b/src/frontend/apps/conversations/src/features/chat/components/__tests__/MessageItem.test.tsx @@ -50,6 +50,12 @@ jest.mock('../ToolInvocationItem', () => ({ ToolInvocationItem: () =>
, })); +jest.mock('../TruncatedResponseMessage', () => ({ + TruncatedResponseMessage: () => ( + <div data-testid="truncated-response-message" />
+ ), +})); + describe('splitIntoBlocks', () => { describe('basic splitting', () => { it('returns empty array for empty content', () => { @@ -540,3 +546,84 @@ describe('MessageItem', () => { }); }); }); + +describe('truncated response annotation', () => { + const baseProps = { + isLastMessage: true, + isLastAssistantMessage: true, + isFirstConversationMessage: false, + streamingMessageHeight: null, + status: 'ready' as const, + conversationId: 'conv-1', + isSourceOpen: null, + isMobile: false, + onCopyToClipboard: jest.fn(), + onOpenSources: jest.fn(), + getMetadata: jest.fn(), + }; + + it('renders TruncatedResponseMessage when annotations contain truncated:true', () => { + const message = { + id: 'msg-1', + role: 'assistant' as const, + content: 'Hello world', + annotations: [{ truncated: true }], + parts: [], + }; + + render( + + + + + + + , + ); + + expect(screen.getByTestId('truncated-response-message')).toBeInTheDocument(); + }); + + it('does not render TruncatedResponseMessage when annotations are absent', () => { + const message = { + id: 'msg-1', + role: 'assistant' as const, + content: 'Hello world', + parts: [], + }; + + render( + + + + + + + , + ); + + expect(screen.queryByTestId('truncated-response-message')).not.toBeInTheDocument(); + }); + + it('does not render TruncatedResponseMessage when annotations do not contain truncated:true', () => { + const message = { + id: 'msg-1', + role: 'assistant' as const, + content: 'Hello world', + annotations: [{ co2_impact: 0.001 }], + parts: [], + }; + + render( + + + + + + + , + ); + + expect(screen.queryByTestId('truncated-response-message')).not.toBeInTheDocument(); + }); +});