Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
import java.lang.reflect.Parameter;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
Expand Down Expand Up @@ -105,6 +107,9 @@ r, extractGenericType(method)))
.onErrorResume(this::handleError))
.onErrorResume(this::handleError);

} else if (returnType == Flux.class) {
return invokeFlux(toolObject, method, input, agent, context, emitter, converter);

} else {
// Sync method: wrap in Mono.fromCallable
return Mono.fromCallable(
Expand All @@ -119,6 +124,101 @@ r, extractGenericType(method)))
}
}

private Mono<ToolResultBlock> invokeFlux(
Object toolObject,
Method method,
Map<String, Object> input,
Agent agent,
ToolExecutionContext context,
ToolEmitter emitter,
ToolResultConverter converter) {
Type itemType = extractGenericType(method);

return Mono.fromCallable(
() -> {
method.setAccessible(true);
Object[] args =
convertParameters(method, input, agent, context, emitter);
@SuppressWarnings("unchecked")
Flux<Object> flux = (Flux<Object>) method.invoke(toolObject, args);
return flux != null ? flux : Flux.empty();
})
.flatMap(
flux ->
flux.doOnNext(
item ->
emitFluxChunk(
emitter, converter, item, itemType))
.collectList()
.map(
items ->
converter.convert(
aggregateFluxItems(items, itemType),
resolveFluxAggregateType(
items, itemType)))
.onErrorResume(this::handleError))
.onErrorResume(this::handleError);
}

private void emitFluxChunk(
ToolEmitter emitter, ToolResultConverter converter, Object item, Type itemType) {
if (item == null) {
return;
}
emitter.emit(toStreamingChunk(item, itemType, converter));
}

private ToolResultBlock toStreamingChunk(
Object item, Type itemType, ToolResultConverter converter) {
if (item instanceof ToolResultBlock) {
return (ToolResultBlock) item;
}
if (item instanceof CharSequence
|| item instanceof Number
|| item instanceof Boolean
|| item instanceof Character) {
return ToolResultBlock.text(String.valueOf(item));
}
return converter.convert(item, itemType);
}

private Object aggregateFluxItems(List<Object> items, Type itemType) {
if (shouldConcatenateFluxItems(items, itemType)) {
StringBuilder aggregated = new StringBuilder();
for (Object item : items) {
if (item != null) {
aggregated.append(item);
}
}
return aggregated.toString();
}
if (items.isEmpty()) {
return null;
}
if (items.size() == 1) {
return items.get(0);
}
return items;
}

private Type resolveFluxAggregateType(List<Object> items, Type itemType) {
if (shouldConcatenateFluxItems(items, itemType)) {
return String.class;
}
if (items.size() == 1) {
return itemType;
}
return List.class;
}

private boolean shouldConcatenateFluxItems(List<Object> items, Type itemType) {
if (itemType == String.class || itemType == CharSequence.class) {
return true;
}
return !items.isEmpty()
&& items.stream().allMatch(item -> item == null || item instanceof CharSequence);
}

/**
* Convert input parameters to method arguments with automatic injection support.
*
Expand Down Expand Up @@ -363,7 +463,7 @@ private ToolResultBlock handleInvocationError(Throwable e) {
}

/**
* Extract generic type from method return type (for CompletableFuture<T> or Mono<T>).
* Extract generic type from method return type (for CompletableFuture<T>, Mono<T>, or Flux<T>).
*
* @param method the method
* @return the generic type, or null if not found
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.agentscope.core.tool.test.SampleTools;
import io.agentscope.core.util.JsonUtils;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.BeforeEach;
Expand All @@ -34,7 +35,7 @@
import org.junit.jupiter.api.Test;

/**
* Tests for async tool execution with CompletableFuture and Mono return types.
* Tests for async tool execution with CompletableFuture, Mono, and Flux return types.
*/
@Tag("unit")
@DisplayName("Async Tool Tests")
Expand Down Expand Up @@ -90,6 +91,56 @@ void shouldExecuteMonoAsyncTool() {
assertEquals("\"HelloWorld\"", extractFirstText(response));
}

@Test
@DisplayName("Should execute Flux async tool")
void shouldExecuteFluxAsyncTool() {
Map<String, Object> input = Map.of("str1", "Hello", "str2", "World");
ToolUseBlock toolCall =
ToolUseBlock.builder()
.id("call-async-flux")
.name("async_flux_concat")
.input(input)
.content(JsonUtils.getJsonCodec().toJson(input))
.build();

ToolResultBlock response =
toolkit.callTool(ToolCallParam.builder().toolUseBlock(toolCall).build())
.block(TIMEOUT);

assertNotNull(response, "Response should not be null");
assertEquals("\"HelloWorld\"", extractFirstText(response));
}

@Test
@DisplayName("Should emit Flux chunks while aggregating final tool result")
void shouldEmitFluxChunksWhileAggregatingFinalToolResult() {
List<String> chunkToolIds = new ArrayList<>();
List<String> chunkTexts = new ArrayList<>();
toolkit.setChunkCallback(
(toolUse, chunk) -> {
chunkToolIds.add(toolUse.getId());
chunkTexts.add(extractFirstText(chunk));
});

Map<String, Object> input = Map.of("str1", "Alpha", "str2", "Beta");
ToolUseBlock toolCall =
ToolUseBlock.builder()
.id("call-async-flux-chunk")
.name("async_flux_concat")
.input(input)
.content(JsonUtils.getJsonCodec().toJson(input))
.build();

ToolResultBlock response =
toolkit.callTool(ToolCallParam.builder().toolUseBlock(toolCall).build())
.block(TIMEOUT);

assertNotNull(response, "Response should not be null");
assertEquals(List.of("call-async-flux-chunk", "call-async-flux-chunk"), chunkToolIds);
assertEquals(List.of("Alpha", "Beta"), chunkTexts);
assertEquals("\"AlphaBeta\"", extractFirstText(response));
}

@Test
@DisplayName("Should execute async tool with delay")
void shouldExecuteAsyncToolWithDelay() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
Expand Down Expand Up @@ -205,6 +206,26 @@ public Mono<String> suspendToolMonoSync(
@ToolParam(name = "reason", description = "reason") String reason) {
throw new ToolSuspendException(reason);
}

public Flux<String> fluxConcat(
@ToolParam(name = "prefix", description = "prefix") String prefix,
@ToolParam(name = "suffix", description = "suffix") String suffix) {
return Flux.just(prefix, suffix);
}

public Flux<Integer> fluxSingleNumber(
@ToolParam(name = "value", description = "value") Integer value) {
return Flux.just(value);
}

public Flux<Integer> fluxNumbers(
@ToolParam(name = "start", description = "start") Integer start) {
return Flux.just(start, start + 1, start + 2);
}

public Flux<String> emptyFluxString() {
return Flux.empty();
}
}

/** Test POJO for generic type testing (Issue #677). */
Expand Down Expand Up @@ -867,6 +888,75 @@ void testGenericMap_WithCustomClassValue() throws Exception {
}

/** Test nested generic types like List&lt;List&lt;Integer&gt;&gt;. */
@Test
void testFluxStringAggregationAndChunkEmission() throws Exception {
TestTools tools = new TestTools();
Method method = TestTools.class.getMethod("fluxConcat", String.class, String.class);

Map<String, Object> input = new HashMap<>();
input.put("prefix", "Hello");
input.put("suffix", "World");

List<String> emittedChunks = new ArrayList<>();
ToolUseBlock toolUseBlock = new ToolUseBlock("flux-id", method.getName(), input);
ToolCallParam param =
ToolCallParam.builder()
.toolUseBlock(toolUseBlock)
.input(input)
.emitter(chunk -> emittedChunks.add(ToolTestUtils.extractContent(chunk)))
.build();

ToolResultBlock response =
invoker.invokeAsync(tools, method, param, responseConverter).block();

Assertions.assertNotNull(response);
Assertions.assertFalse(ToolTestUtils.isErrorResponse(response));
Assertions.assertEquals("\"HelloWorld\"", ToolTestUtils.extractContent(response));
Assertions.assertEquals(List.of("Hello", "World"), emittedChunks);
}

@Test
void testFluxSingleValueAggregation() throws Exception {
TestTools tools = new TestTools();
Method method = TestTools.class.getMethod("fluxSingleNumber", Integer.class);

Map<String, Object> input = new HashMap<>();
input.put("value", 7);

ToolResultBlock response = invokeWithParam(tools, method, input);

Assertions.assertNotNull(response);
Assertions.assertFalse(ToolTestUtils.isErrorResponse(response));
Assertions.assertEquals("7", ToolTestUtils.extractContent(response));
}

@Test
void testFluxMultipleValuesAggregateToJsonArray() throws Exception {
TestTools tools = new TestTools();
Method method = TestTools.class.getMethod("fluxNumbers", Integer.class);

Map<String, Object> input = new HashMap<>();
input.put("start", 3);

ToolResultBlock response = invokeWithParam(tools, method, input);

Assertions.assertNotNull(response);
Assertions.assertFalse(ToolTestUtils.isErrorResponse(response));
Assertions.assertEquals("[3,4,5]", ToolTestUtils.extractContent(response));
}

@Test
void testEmptyFluxStringAggregatesToEmptyString() throws Exception {
TestTools tools = new TestTools();
Method method = TestTools.class.getMethod("emptyFluxString");

ToolResultBlock response = invokeWithParam(tools, method, new HashMap<>());

Assertions.assertNotNull(response);
Assertions.assertFalse(ToolTestUtils.isErrorResponse(response));
Assertions.assertEquals("\"\"", ToolTestUtils.extractContent(response));
}

@Test
void testNestedGenericList() throws Exception {
TestTools tools = new TestTools();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.agentscope.core.tool.ToolParam;
import java.time.Duration;
import java.util.concurrent.CompletableFuture;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
Expand Down Expand Up @@ -132,6 +133,18 @@ public Mono<String> asyncConcat(
return Mono.fromCallable(() -> str1 + str2);
}

/**
* Async tool using Flux that streams string chunks.
*/
@Tool(
name = "async_flux_concat",
description = "Asynchronously stream and concatenate two strings")
public Flux<String> asyncFluxConcat(
@ToolParam(name = "str1", description = "First string") String str1,
@ToolParam(name = "str2", description = "Second string") String str2) {
return Flux.just(str1, str2).delayElements(Duration.ofMillis(25));
}

/**
* Async tool using Mono that simulates delay.
*/
Expand Down
Loading