Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to
- ✅(front) add tests for SourceItem component
- ✨(documents) make document context hybrid
- 🐛(fix) add prevent_url_hallucination instruction to ConversationAgent
- ✨(back) add output token limit per message

### Changed

Expand Down
1 change: 1 addition & 0 deletions docs/env.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ These are the environment variables you can set for the `conversations-backend`
| AI_MODEL | AI Model name to use (used in default LLM configuration, not for production use) | |
| AI_AGENT_INSTRUCTIONS | Base instruction for the AI agent (used in default LLM configuration, not for production use) | You are a helpful assistant. Wrap formulas... |
| AI_AGENT_TOOLS | List of enabled tools for the agent (used in default LLM configuration, not for production use) | [] |
| LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE | Maximum number of output tokens the LLM can generate per message. When reached, the response is truncated and the user is notified. | 8192 |
| CONVERSION_API_ENDPOINT | Conversion API endpoint | convert-markdown |
| CONVERSION_API_CONTENT_FIELD | Conversion api content field | content |
| CONVERSION_API_TIMEOUT | Conversion api timeout | 30 |
Expand Down
25 changes: 25 additions & 0 deletions src/backend/chat/clients/pydantic_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
from asgiref.sync import sync_to_async
from langfuse import get_client
from pydantic_ai import Agent, InstrumentationSettings, RunContext, RunUsage
from pydantic_ai.capabilities import Hooks
from pydantic_ai.messages import (
BinaryContent,
DocumentUrl,
Expand All @@ -123,6 +124,7 @@
UserPromptPart,
)
from pydantic_ai.models import Model, infer_model_profile
from pydantic_ai.settings import ModelSettings

from core.feature_flags.helpers import is_feature_enabled

Expand Down Expand Up @@ -302,6 +304,15 @@ def __init__( # pylint: disable=too-many-arguments,too-many-positional-argument
self._is_smart_search_enabled = user.allow_smart_web_search
self._fake_streaming_delay = settings.FAKE_STREAMING_DELAY

self._last_finish_reason: str | None = None

self._truncation_hooks = Hooks()

@self._truncation_hooks.on.after_model_request
async def _detect_truncation(_ctx, *, request_context, response): # pylint: disable=unused-argument
self._last_finish_reason = response.finish_reason
return response

self._context_deps = ContextDeps(
conversation=conversation,
user=user,
Expand All @@ -321,6 +332,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-positional-argument
if self._langfuse_available
else False,
deps_type=ContextDeps,
capabilities=[self._truncation_hooks],
)
add_document_rag_search_tool_from_setting(self.conversation_agent, self.user)

Expand Down Expand Up @@ -419,6 +431,7 @@ async def _clean(self):
It can be used to release resources or perform any necessary cleanup.
"""
self._last_stop_check = 0
self._last_finish_reason = None
await cache.adelete(self._stop_cache_key)

# --------------------------------------------------------------------- #
Expand Down Expand Up @@ -1122,6 +1135,10 @@ async def _finalize_conversation( # pylint: disable=too-many-arguments,too-many
]
)
self._update_langfuse_trace(run_output)

if self._last_finish_reason == "length":
yield events_v4.MessageAnnotationPart(annotations=[{"truncated": True}])

# Vercel finish message
yield events_v4.FinishMessagePart(
finish_reason=events_v4.FinishReason.STOP,
Expand Down Expand Up @@ -1196,11 +1213,14 @@ async def _run_agent( # pylint: disable=too-many-locals
if history[-1].parts and history[-1].parts[-1].part_kind == "tool-return":
history.append(ModelResponse(parts=[TextPart(content="ok")], kind="response"))

model_settings = ModelSettings(max_tokens=settings.LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE)

async with self.conversation_agent.iter(
[user_prompt] + input_images,
message_history=history, # history will pass through agent's history_processors
deps=self._context_deps,
toolsets=mcp_servers,
model_settings=model_settings,
) as run:
state = StreamingState()
async for event in self._process_agent_nodes(run, state, langfuse):
Expand Down Expand Up @@ -1278,6 +1298,11 @@ def _prepare_update_conversation(
_output_ui_message.annotations = []
_output_ui_message.annotations.append({"co2_impact": co2_impact})

if self._last_finish_reason == "length":
if _output_ui_message.annotations is None:
_output_ui_message.annotations = []
_output_ui_message.annotations.append({"truncated": True})

usage["co2_impact"] += self.conversation.agent_usage.get("co2_impact", 0)
usage["promptTokens"] += self.conversation.agent_usage.get("promptTokens", 0)
usage["completionTokens"] += self.conversation.agent_usage.get("completionTokens", 0)
Expand Down
201 changes: 201 additions & 0 deletions src/backend/chat/tests/clients/pydantic_ai/test_output_token_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
"""Tests for the output token limit feature."""

from unittest.mock import MagicMock, patch

from django.conf import settings as django_settings

import pytest
from asgiref.sync import sync_to_async
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.settings import ModelSettings

from chat.ai_sdk_types import TextUIPart, UIMessage
from chat.clients.pydantic_ai import AIAgentService
from chat.factories import ChatConversationFactory

pytestmark = pytest.mark.django_db()


@pytest.fixture(autouse=True)
def base_settings(settings):
"""Set up base settings for all tests in this module."""
settings.AI_BASE_URL = "https://api.llm.com/v1/"
settings.AI_API_KEY = "test-key"
settings.AI_MODEL = "model-123"
settings.AI_AGENT_INSTRUCTIONS = "You are a helpful assistant"
settings.AI_AGENT_TOOLS = []
settings.LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE = 1000


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture(name="ui_messages")
def ui_messages_fixture():
"""Fixture providing a single user UIMessage."""
return [
UIMessage(
id="msg-1",
role="user",
content="Hello",
parts=[TextUIPart(type="text", text="Hello")],
)
]


@pytest.fixture(name="simple_model")
def simple_model_fixture():
"""Fixture providing a simple streaming FunctionModel."""

async def _model(_messages: list[ModelMessage], _info: AgentInfo):
yield "Hello world"

return FunctionModel(stream_function=_model)

Comment thread
coderabbitai[bot] marked this conversation as resolved.

# ---------------------------------------------------------------------------
# Settings tests
# ---------------------------------------------------------------------------


def test_llm_max_output_tokens_per_message_setting_exists():
"""Setting must exist with a positive integer value."""
assert hasattr(django_settings, "LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE")
assert isinstance(django_settings.LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE, int)
assert django_settings.LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE == 1000


# ---------------------------------------------------------------------------
# Hook / flag tests
# ---------------------------------------------------------------------------


def test_response_truncated_flag_initialized_to_false():
"""AIAgentService must start with _last_finish_reason=None."""
conversation = ChatConversationFactory()
service = AIAgentService(conversation, user=conversation.owner)
assert service._last_finish_reason is None # pylint: disable=protected-access


@pytest.mark.asyncio
async def test_response_truncated_flag_reset_by_clean():
"""_clean() must reset _last_finish_reason to None."""
conversation = await sync_to_async(ChatConversationFactory)()
service = AIAgentService(conversation, user=conversation.owner)
service._last_finish_reason = "length" # pylint: disable=protected-access
await service._clean() # pylint: disable=protected-access
assert service._last_finish_reason is None # pylint: disable=protected-access


@pytest.mark.asyncio
async def test_hook_sets_truncated_flag_when_finish_reason_is_length():
"""after_model_request hook must set _last_finish_reason='length' for finish_reason='length'."""
conversation = await sync_to_async(ChatConversationFactory)()
service = AIAgentService(conversation, user=conversation.owner)

mock_response = ModelResponse(
parts=[TextPart(content="Hello wo...")],
finish_reason="length",
)
mock_ctx = MagicMock()
mock_request_context = MagicMock()

result = await service.conversation_agent._root_capability.after_model_request( # pylint: disable=protected-access
mock_ctx, request_context=mock_request_context, response=mock_response
)

assert service._last_finish_reason == "length" # pylint: disable=protected-access
assert result == mock_response


@pytest.mark.asyncio
async def test_hook_does_not_set_truncated_flag_when_finish_reason_is_stop():
"""after_model_request hook must set _last_finish_reason='stop' for finish_reason='stop'."""
conversation = await sync_to_async(ChatConversationFactory)()
service = AIAgentService(conversation, user=conversation.owner)

mock_response = ModelResponse(
parts=[TextPart(content="Hello world")],
finish_reason="stop",
)
mock_ctx = MagicMock()
mock_request_context = MagicMock()

await service.conversation_agent._root_capability.after_model_request( # pylint: disable=protected-access
mock_ctx, request_context=mock_request_context, response=mock_response
)

assert service._last_finish_reason == "stop" # pylint: disable=protected-access


# ---------------------------------------------------------------------------
# Stream event tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_run_agent_passes_max_tokens_model_settings(ui_messages):
"""_run_agent must pass
ModelSettings(max_tokens=LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE) to agent.iter()."""
conversation = await sync_to_async(ChatConversationFactory)()
service = AIAgentService(conversation, user=conversation.owner)

captured_kwargs = {}
original_iter = service.conversation_agent.iter

def capturing_iter(*args, **kwargs):
captured_kwargs.update(kwargs)
return original_iter(*args, **kwargs)

async def simple_model(_messages: list[ModelMessage], _info: AgentInfo):
yield "ok"

with service.conversation_agent.override(model=FunctionModel(stream_function=simple_model)):
with patch.object(service.conversation_agent, "iter", side_effect=capturing_iter):
async for _ in service.stream_data_async(ui_messages):
pass

Check warning on line 159 in src/backend/chat/tests/clients/pydantic_ai/test_output_token_limit.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Either remove or fill this block of code.

See more on https://sonarcloud.io/project/issues?id=suitenumerique_conversations&issues=AZ4DHhCPkUtQvV1GLvlS&open=AZ4DHhCPkUtQvV1GLvlS&pullRequest=458

assert "model_settings" in captured_kwargs
assert captured_kwargs["model_settings"] == ModelSettings(max_tokens=1000)


@pytest.mark.asyncio
async def test_truncation_annotation_emitted_when_flag_is_set(ui_messages):
"""When the model returns finish_reason='length', stream must emit
MessageAnnotationPart with truncated=True."""
conversation = await sync_to_async(ChatConversationFactory)()
service = AIAgentService(conversation, user=conversation.owner)
service._support_streaming = False # pylint: disable=protected-access

def _length_model(_messages: list[ModelMessage], _info: AgentInfo):
return ModelResponse(parts=[TextPart(content="Hello world")], finish_reason="length")

length_model = FunctionModel(function=_length_model)

chunks = []
with service.conversation_agent.override(model=length_model):
async for chunk in service.stream_data_async(ui_messages):
chunks.append(chunk)

# The MessageAnnotationPart is encoded as a line containing '"truncated"'
annotation_chunks = [c for c in chunks if '"truncated"' in c]
assert len(annotation_chunks) == 1
assert '"truncated"' in annotation_chunks[0]


@pytest.mark.asyncio
async def test_truncation_annotation_not_emitted_when_flag_is_false(ui_messages, simple_model):
"""When _last_finish_reason is not 'length', stream must NOT emit a truncation annotation."""
conversation = await sync_to_async(ChatConversationFactory)()
service = AIAgentService(conversation, user=conversation.owner)

chunks = []
with service.conversation_agent.override(model=simple_model):
async for chunk in service.stream_data_async(ui_messages):
chunks.append(chunk)

annotation_chunks = [c for c in chunks if '"truncated"' in c]
assert len(annotation_chunks) == 0
1 change: 1 addition & 0 deletions src/backend/chat/tests/test_ai_agent_service_co2.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def service_fixture(conversation):
"""Instantiate AIAgentService without __init__, injecting only the conversation."""
s = object.__new__(AIAgentService)
s.conversation = conversation
s._last_finish_reason = None
return s


Expand Down
5 changes: 5 additions & 0 deletions src/backend/conversations/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,11 @@ class Base(BraveSettings, Configuration):
environ_name="DEFAULT_ALLOW_CONVERSATION_ANALYTICS",
environ_prefix=None,
)
LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE = values.IntegerValue(
8192,
environ_name="LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE",
environ_prefix=None,
)
Comment on lines +681 to +685
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate token limit as strictly positive at startup.

A non-positive value here can flow into generation settings and break response generation at runtime. Add a fail-fast validation in post_setup.

Suggested fix
 class Base(BraveSettings, Configuration):
@@
     `@classmethod`
     def post_setup(cls):
@@
+        if cls.LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE <= 0:
+            raise ValueError(
+                "LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE must be > 0, "
+                f"got {cls.LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE}."
+            )
+
         # Document context budget ratio must be a fraction (0 disables full inlining,
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/backend/conversations/settings.py` around lines 681 - 685,
ConversationsSettings defines LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE but doesn't
validate it; add a fail-fast check in the class's post_setup method that reads
self.LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE and raises a ValueError (or logger +
sys.exit) if the value is <= 0, with a clear message identifying
LLM_MAX_OUTPUT_TOKENS_PER_MESSAGE so startup fails early and prevents invalid
token limits from propagating into generation settings.

DEFAULT_ALLOW_SMART_WEB_SEARCH = values.BooleanValue(
default=False,
environ_name="DEFAULT_ALLOW_SMART_WEB_SEARCH",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import {
} from '@/features/chat/components/MessageBlock';
import { SourceItemList } from '@/features/chat/components/SourceItemList';
import { ToolInvocationItem } from '@/features/chat/components/ToolInvocationItem';
import { TruncatedResponseMessage } from '@/features/chat/components/TruncatedResponseMessage';

// Memoized blocks list to prevent parent re-renders from causing block remounts
const BlocksList = React.memo(
Expand Down Expand Up @@ -242,6 +243,19 @@ const MessageItemComponent: React.FC<MessageItemProps> = ({
);
}, [toolInvocationParts]);

const isTruncated = React.useMemo(() => {
return (
message.role === 'assistant' &&
Array.isArray(message.annotations) &&
message.annotations.some(
(a) =>
typeof a === 'object' &&
a !== null &&
(a as Record<string, unknown>).truncated === true,
)
);
}, [message.annotations, message.role]);

const activeToolInvocation = React.useMemo(() => {
const tool = toolInvocationParts.find(
(part) => part.toolInvocation.toolName !== 'document_parsing',
Expand Down Expand Up @@ -478,6 +492,7 @@ const MessageItemComponent: React.FC<MessageItemProps> = ({
<SourceItemList parts={sourceParts} getMetadata={getMetadata} />
</Box>
)}
{isTruncated && <TruncatedResponseMessage />}
</Box>
</Box>
</Box>
Expand Down Expand Up @@ -510,6 +525,23 @@ const arePropsEqual = (
return false;
}

// Check annotations changes (for truncation and other metadata)
const prevAnnotationsLength = prevProps.message.annotations?.length ?? 0;
const nextAnnotationsLength = nextProps.message.annotations?.length ?? 0;
if (prevAnnotationsLength !== nextAnnotationsLength) {
return false;
}
const isTruncated = (annotations: typeof prevProps.message.annotations) =>
annotations?.some(
(a) =>
typeof a === 'object' &&
a !== null &&
(a as Record<string, unknown>).truncated === true,
) ?? false;
if (isTruncated(prevProps.message.annotations) !== isTruncated(nextProps.message.annotations)) {
return false;
}

// Check attachments
const prevAttachmentsLength =
prevProps.message.experimental_attachments?.length ?? 0;
Expand Down
Loading
Loading