Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

186 changes: 144 additions & 42 deletions src/agent/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,52 +91,84 @@ impl AgentRunner {
return Ok(());
}

// Create an internal channel to collect events from this LLM call
let (inner_tx, mut inner_rx) = mpsc::channel::<AgentEvent>(32);

provider
.call_streaming(&current_messages, tools, system_prompt, inner_tx)
.await?;

// Collect events, forwarding text/usage/error to client,
// collecting tool_use calls for dispatch
let mut tool_calls: Vec<(String, String, serde_json::Value)> = Vec::new();
let tool_calls: Vec<(String, String, serde_json::Value)> = {
// Create an internal channel to collect events from this LLM call.
// We must drain this channel while provider streaming is in-flight;
// otherwise long responses can fill the buffer and deadlock.
let (inner_tx, mut inner_rx) = mpsc::channel::<AgentEvent>(32);
let mut tool_calls: Vec<(String, String, serde_json::Value)> = Vec::new();
let mut provider_done = false;
let mut provider_error: Option<anyhow::Error> = None;

let provider_call =
provider.call_streaming(&current_messages, tools, system_prompt, inner_tx);
tokio::pin!(provider_call);

loop {
tokio::select! {
result = &mut provider_call, if !provider_done => {
provider_done = true;
if let Err(e) = result {
provider_error = Some(e);
}
}
maybe_event = inner_rx.recv() => {
let Some(event) = maybe_event else {
if !provider_done {
if let Err(e) = (&mut provider_call).await {
provider_error = Some(e);
}
}
break;
};

while let Some(event) = inner_rx.recv().await {
match event {
AgentEvent::Text(ref _t) => {
let _ = tx.send(event).await;
}
AgentEvent::ToolUse {
ref id,
ref name,
ref input,
} => {
// Forward to client so they can observe
let _ = tx
.send(AgentEvent::ToolUse {
id: id.clone(),
name: name.clone(),
input: input.clone(),
})
.await;
tool_calls.push((id.clone(), name.clone(), input.clone()));
}
AgentEvent::Usage { .. } => {
let _ = tx.send(event).await;
}
AgentEvent::Error(ref _e) => {
let _ = tx.send(event).await;
}
AgentEvent::Done => {
// Don't forward Done yet — we may need to continue the loop
match event {
AgentEvent::Text(ref _t) => {
let _ = tx.send(event).await;
}
AgentEvent::ToolUse {
ref id,
ref name,
ref input,
} => {
// Forward to client so they can observe
let _ = tx
.send(AgentEvent::ToolUse {
id: id.clone(),
name: name.clone(),
input: input.clone(),
})
.await;
tool_calls.push((id.clone(), name.clone(), input.clone()));
}
AgentEvent::Usage { .. } => {
let _ = tx.send(event).await;
}
AgentEvent::Error(ref _e) => {
let _ = tx.send(event).await;
}
AgentEvent::Done => {
// Don't forward Done yet — we may need to continue the loop
}
AgentEvent::ToolResult { .. } => {
// Shouldn't come from provider, but forward if it does
let _ = tx.send(event).await;
}
}
}
}
AgentEvent::ToolResult { .. } => {
// Shouldn't come from provider, but forward if it does
let _ = tx.send(event).await;

if provider_done && inner_rx.is_closed() && inner_rx.is_empty() {
break;
}
}
}

if let Some(e) = provider_error {
return Err(e);
}

tool_calls
};

// If no tool calls, we're done
if tool_calls.is_empty() {
Expand Down Expand Up @@ -366,3 +398,73 @@ impl Default for AgentRunner {
Self::new()
}
}

#[cfg(test)]
mod tests {
use super::{AgentEvent, AgentRunner, providers};
use crate::sandbox::PluginHost;
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::{RwLock, mpsc};
use tokio::time::{Duration, timeout};

struct BurstProvider {
chunks: usize,
}

#[async_trait]
impl providers::LlmProvider for BurstProvider {
async fn call_streaming(
&self,
_messages: &[serde_json::Value],
_tools: &[serde_json::Value],
_system_prompt: Option<&str>,
tx: mpsc::Sender<AgentEvent>,
) -> anyhow::Result<()> {
for i in 0..self.chunks {
tx.send(AgentEvent::Text(format!("chunk-{i}"))).await?;
}
tx.send(AgentEvent::Done).await?;
Ok(())
}
}

#[tokio::test]
async fn run_with_tools_drains_stream_while_provider_is_running() {
let runner = AgentRunner::new();
let provider = BurstProvider { chunks: 64 };
let plugins = Arc::new(RwLock::new(PluginHost::new()));
let (tx, mut rx) = mpsc::channel::<AgentEvent>(256);

timeout(
Duration::from_secs(5),
runner.run_with_tools(
&provider,
vec![serde_json::json!({
"role": "user",
"content": "hello",
})],
&[],
None,
&plugins,
tx,
),
)
.await
.expect("runner should not deadlock")
.expect("runner should succeed");

let mut text_count = 0usize;
let mut saw_done = false;
while let Ok(event) = rx.try_recv() {
match event {
AgentEvent::Text(_) => text_count += 1,
AgentEvent::Done => saw_done = true,
_ => {}
}
}

assert_eq!(text_count, 64);
assert!(saw_done);
}
}
Loading
Loading