Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 83 additions & 23 deletions src/google/adk/agents/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,14 @@ class LlmAgent(BaseAgent):
- Extracts agent reply for later use, such as in tools, callbacks, etc.
- Connects agents to coordinate with each other.
"""
accumulate_output_key: bool = True
"""Whether to accumulate streamed output fragments into `output_key`.

When True (default) streamed fragments received before tool calls are
appended into the session state under `output_key` so the final saved value
includes all streamed text. When False, preserves legacy behavior where
only the final response's text from the last event is saved.
"""
# Controlled input/output configurations - End

# Advance features - Start
Expand Down Expand Up @@ -486,7 +494,7 @@ async def _run_async_impl(
should_pause = False
async with Aclosing(self._llm_flow.run_async(ctx)) as agen:
async for event in agen:
self.__maybe_save_output_to_state(event)
self.__maybe_save_output_to_state(event, ctx)
yield event
if ctx.should_pause_invocation(event):
# Do not pause immediately, wait until the long-running tool call is
Expand All @@ -510,7 +518,7 @@ async def _run_live_impl(
) -> AsyncGenerator[Event, None]:
async with Aclosing(self._llm_flow.run_live(ctx)) as agen:
async for event in agen:
self.__maybe_save_output_to_state(event)
self.__maybe_save_output_to_state(event, ctx)
yield event
if ctx.end_invocation:
return
Expand Down Expand Up @@ -827,8 +835,16 @@ def __get_transfer_to_agent_or_none(
return self.__get_agent_to_run(event.actions.transfer_to_agent)
return None

def __maybe_save_output_to_state(self, event: Event):
"""Saves the model output to state if needed."""
def __maybe_save_output_to_state(
self, event: Event, ctx: Optional[InvocationContext] = None
):
"""Saves the model output to state if needed.

Backwards-compatible: if `ctx` is None, keeps the original behavior of
only saving on final responses. If `ctx` is provided, append streamed
partial text to the existing session state so intermediate streamed
fragments are not lost when tools are called.
"""
# skip if the event was authored by some other agent (e.g. current agent
# transferred to another agent)
if event.author != self.name:
Expand All @@ -842,33 +858,75 @@ def __maybe_save_output_to_state(self, event: Event):
if not self.output_key:
return

# Handle text responses
if event.is_final_response() and event.content and event.content.parts:
# Collect text parts from this event
if not (event.content and event.content.parts):
return

# Skip if no text parts at all to avoid overwriting state_delta values
# already set (e.g. after_tool_callback with skip_summarization
# on function_response-only events).
has_text_part = any(
part.text is not None and not part.thought
for part in event.content.parts
)
result = ''.join(
part.text for part in event.content.parts if part.text and not part.thought
)

if not has_text_part:
# If no invocation context was provided, preserve legacy behavior: only
# save on final responses and apply schema validation then.
if ctx is None:
if not event.is_final_response():
return

result = ''.join(
part.text
for part in event.content.parts
if part.text and not part.thought
)
if self.output_schema:
# If the result from the final chunk is just whitespace or empty,
# it means this is an empty final chunk of a stream.
# Do not attempt to parse it as JSON.
if not result.strip():
return
result = validate_schema(self.output_schema, result)
elif not result:
return
event.actions.state_delta[self.output_key] = result
return

# When ctx is provided, append partial streamed results into session
# state so earlier streamed text is preserved across tool calls. If the
# caller disabled accumulation via `accumulate_output_key`, fall back to
# legacy behavior: ignore non-final fragments and save only the final
# fragment (without combining previous fragments).
# Read the existing value from the session (may be empty).
try:
previous = ctx.session.state.get(self.output_key, '') or ''
except Exception:
previous = ''
# If accumulation disabled, ignore non-final fragments and save only
# the final fragment as legacy behavior.
if not self.accumulate_output_key:
if not event.is_final_response():
return
# Final-only behavior: validate only the final fragment.
if self.output_schema:
if not result.strip():
return
validated = validate_schema(self.output_schema, result)
event.actions.state_delta[self.output_key] = validated
return
if not result:
return
event.actions.state_delta[self.output_key] = result
return

# Accumulation enabled: Final response combines previous + result
# then validate and save. Non-final events append current fragment to
# previous value so it is available to future finalization.
if event.is_final_response():
combined = (previous or '') + (result or '')
if not combined:
return
if self.output_schema:
if not combined.strip():
return
validated = validate_schema(self.output_schema, combined)
event.actions.state_delta[self.output_key] = validated
return
event.actions.state_delta[self.output_key] = combined
return

# Non-final (streaming) response: append the fragment to previous value.
if result:
event.actions.state_delta[self.output_key] = previous + result
return

@model_validator(mode='after')
def __model_validator_after(self) -> LlmAgent:
Expand Down Expand Up @@ -1000,6 +1058,8 @@ def _parse_config(
kwargs['output_schema'] = resolve_code_reference(config.output_schema)
if config.output_key:
kwargs['output_key'] = config.output_key
if getattr(config, 'accumulate_output_key', None) is not None:
kwargs['accumulate_output_key'] = config.accumulate_output_key
if config.tools:
kwargs['tools'] = cls._resolve_tools(config.tools, config_abs_path)
if config.before_model_callbacks:
Expand Down
5 changes: 5 additions & 0 deletions src/google/adk/agents/llm_agent_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ def _validate_model_sources(self) -> LlmAgentConfig:
default=None, description='Optional. LlmAgent.output_key.'
)

accumulate_output_key: Optional[bool] = Field(
default=None,
description='Optional. When true, streamed fragments are accumulated into the `output_key` across tool calls. When false, only the final response is saved to `output_key`.',
)

include_contents: Literal['default', 'none'] = Field(
default='default', description='Optional. LlmAgent.include_contents.'
)
Expand Down
60 changes: 44 additions & 16 deletions tests/unittests/agents/test_llm_agent_output_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,24 +310,52 @@ def test_maybe_save_output_to_state_skips_function_response_only_event(self):
# The callback-set value should be preserved, not overwritten with ""
assert event.actions.state_delta["result"] == [1, 2, 3]

def test_maybe_save_output_to_state_saves_empty_string_when_text_is_empty(
self,
):
"""Test that output is saved as empty string when part.text is explicitly empty."""
def test_accumulate_output_key_toggle(self):
"""Test that `accumulate_output_key` controls accumulation behavior.

Simulate two streamed fragments separated by a tool call by manually
updating the session state between calls.
"""
class Ctx:
pass

# Prepare a fake invocation context with session.state
ctx = Ctx()
ctx.session = type('S', (), {'state': {}})()

# Case 1: accumulation enabled (default)
agent = LlmAgent(name="test_agent", output_key="result")

# Explicitly construct a part with empty string text
parts = [types.Part(text="")]
content = types.Content(role="model", parts=parts)
event = Event(
invocation_id="test_invocation",
author="test_agent",
content=content,
actions=EventActions(),
# First (partial) fragment
event1 = create_test_event(
author="test_agent", content_text="Intro: ", is_final=False
)
agent._LlmAgent__maybe_save_output_to_state(event1, ctx)
# Simulate session update that runner would do
ctx.session.state["result"] = event1.actions.state_delta.get("result", "")

agent._LlmAgent__maybe_save_output_to_state(event)
# Final fragment
event2 = create_test_event(author="test_agent", content_text="Conclusion", is_final=True)
agent._LlmAgent__maybe_save_output_to_state(event2, ctx)

# With accumulation enabled, final saved value should include both parts
assert event2.actions.state_delta["result"] == "Intro: Conclusion"

# Case 2: accumulation disabled
ctx2 = Ctx()
ctx2.session = type('S', (), {'state': {}})()
agent2 = LlmAgent(name="test_agent", output_key="result", accumulate_output_key=False)

event1b = create_test_event(
author="test_agent", content_text="Intro: ", is_final=False
)
agent2._LlmAgent__maybe_save_output_to_state(event1b, ctx2)
# Simulate runner updating session with the partial (though when disabled
# we expect the partial not to be used for final save)
ctx2.session.state["result"] = event1b.actions.state_delta.get("result", "")

event2b = create_test_event(author="test_agent", content_text="Conclusion", is_final=True)
agent2._LlmAgent__maybe_save_output_to_state(event2b, ctx2)

# Assert key exists and value is empty string
assert "result" in event.actions.state_delta
assert not event.actions.state_delta["result"]
# With accumulation disabled, final saved value should be only final fragment
assert event2b.actions.state_delta["result"] == "Conclusion"