Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ protected Mono<Msg> doCall(List<Msg> msgs) {
LoggerUtil.debug(log, "[{}] A2aAgent call with input messages: ", currentRequestId);
LoggerUtil.logTextMsgDetail(log, memory.getMessages());
clientEventContext.setHooks(getSortedHooks());
clientEventContext.setInputMessages(memory.getMessages());
return Mono.defer(
() -> {
Message message =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@
import io.a2a.spec.Task;
import io.agentscope.core.a2a.agent.A2aAgent;
import io.agentscope.core.hook.Hook;
import io.agentscope.core.hook.PostReasoningEvent;
import io.agentscope.core.hook.PreReasoningEvent;
import io.agentscope.core.hook.ReasoningChunkEvent;
import io.agentscope.core.message.Msg;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoSink;

/**
Expand All @@ -40,6 +45,16 @@ public class ClientEventContext {

private Task task;

/**
* Temporarily store the complete historical dialogue context at the time of this call,
* specifically for use in constructing PreReasoning Events using the {@link #publishPreReasoning()} method.
*/
private List<Msg> inputMessages;

// Ensure that lifecycle events are triggered only once
private final AtomicBoolean preReasoningFired = new AtomicBoolean(false);
private final AtomicBoolean postReasoningFired = new AtomicBoolean(false);

public ClientEventContext(String currentRequestId, A2aAgent agent) {
this.currentRequestId = currentRequestId;
this.agent = agent;
Expand Down Expand Up @@ -76,4 +91,70 @@ public Task getTask() {
public void setTask(Task task) {
this.task = task;
}

public void setInputMessages(List<Msg> inputMessages) {
this.inputMessages = inputMessages;
}

// ==========================================
// Unified Event Publishing API
// ==========================================

/**
* Trigger PreReasoningEvent (triggered only once)
*/
void publishPreReasoning() {
if (hooks != null && !hooks.isEmpty() && preReasoningFired.compareAndSet(false, true)) {
List<Msg> msgs = inputMessages == null ? List.of() : inputMessages;
PreReasoningEvent preEvent = new PreReasoningEvent(agent, "A2A", null, msgs);

Mono<PreReasoningEvent> eventMono = Mono.just(preEvent);
for (Hook hook : hooks) {
eventMono = eventMono.flatMap(hook::onEvent);
}
eventMono.block();
}
}

/**
* Trigger ReasoningChunkEvent (streaming process)
*/
void publishReasoningChunk(Msg chunkMsg) {
if (hooks != null && !hooks.isEmpty()) {
publishPreReasoning(); // If not sent Pre before, send Pre first
ReasoningChunkEvent chunkEvent =
new ReasoningChunkEvent(agent, "A2A", null, chunkMsg, chunkMsg);

Mono<ReasoningChunkEvent> eventMono = Mono.just(chunkEvent);
for (Hook hook : hooks) {
eventMono = eventMono.flatMap(hook::onEvent);
}
eventMono.block();
}
}

/**
* Trigger PostReasoningEvent (triggered only once) and return the final reasoning message
* after hooks have had a chance to modify it.
*
* @param finalMsg the original final reasoning message
* @return the hook-modified reasoning message, or {@code finalMsg} if no hooks ran or no
* modification was applied
*/
Msg publishPostReasoning(Msg finalMsg) {
if (hooks != null && !hooks.isEmpty() && postReasoningFired.compareAndSet(false, true)) {
publishPreReasoning();
PostReasoningEvent postEvent = new PostReasoningEvent(agent, "A2A", null, finalMsg);

Mono<PostReasoningEvent> eventMono = Mono.just(postEvent);
for (Hook hook : hooks) {
eventMono = eventMono.flatMap(hook::onEvent);
}
postEvent = eventMono.block();

Msg modifiedMsg = postEvent.getReasoningMessage();
return modifiedMsg != null ? modifiedMsg : finalMsg;
}
return finalMsg;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ public void handle(MessageEvent event, ClientEventContext context) {
Msg msg =
MessageConvertUtil.convertFromMessage(
event.getMessage(), context.getAgent().getName());

// Automatically trigger PreReasoningEvent and PostReasoningEvent
msg = context.publishPostReasoning(msg);

context.getSink().success(msg);
LoggerUtil.info(log, "[{}] A2aAgent complete call.", currentRequestId);
LoggerUtil.debug(log, "[{}] A2aAgent complete with artifact messages: ", currentRequestId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,7 @@ public void handle(TaskEvent event, ClientEventContext context) {
context.getCurrentRequestId(),
task.getId(),
task.getStatus());

context.publishPreReasoning();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import io.a2a.spec.UpdateEvent;
import io.agentscope.core.a2a.agent.utils.LoggerUtil;
import io.agentscope.core.a2a.agent.utils.MessageConvertUtil;
import io.agentscope.core.hook.ReasoningChunkEvent;
import io.agentscope.core.message.Msg;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -93,6 +92,9 @@ public void handle(TaskStatusUpdateEvent event, ClientEventContext context) {
Msg msg =
MessageConvertUtil.convertFromArtifact(
context.getTask().getArtifacts(), context.getAgent().getName());

msg = context.publishPostReasoning(msg);

context.getSink().success(msg);
LoggerUtil.info(log, "[{}] A2aAgent complete call.", currentRequestId);
LoggerUtil.debug(
Expand All @@ -114,9 +116,8 @@ public void handle(TaskStatusUpdateEvent event, ClientEventContext context) {
LoggerUtil.debug(
log, "[{}] A2aAgent task status updated with messages: ", currentRequestId);
LoggerUtil.logTextMsgDetail(log, List.of(msg));
ReasoningChunkEvent chunkEvent =
new ReasoningChunkEvent(context.getAgent(), "A2A", null, msg, msg);
context.getHooks().forEach(hook -> hook.onEvent(chunkEvent).block());

context.publishReasoningChunk(msg);
}
}
}
Expand All @@ -136,9 +137,8 @@ public void handle(TaskArtifactUpdateEvent event, ClientEventContext context) {
LoggerUtil.debug(
log, "[{}] A2aAgent artifact append with messages: ", currentRequestTaskId);
LoggerUtil.logTextMsgDetail(log, List.of(msg));
ReasoningChunkEvent chunkEvent =
new ReasoningChunkEvent(context.getAgent(), "A2A", null, msg, msg);
context.getHooks().forEach(hook -> hook.onEvent(chunkEvent).block());

context.publishReasoningChunk(msg);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@
import io.agentscope.core.agent.Event;
import io.agentscope.core.hook.Hook;
import io.agentscope.core.hook.HookEvent;
import io.agentscope.core.hook.PostReasoningEvent;
import io.agentscope.core.hook.PreCallEvent;
import io.agentscope.core.hook.PreReasoningEvent;
import io.agentscope.core.hook.ReasoningChunkEvent;
import io.agentscope.core.message.Msg;
import java.lang.reflect.Field;
import java.util.HashMap;
Expand All @@ -69,6 +72,7 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
Expand Down Expand Up @@ -282,9 +286,10 @@ void testStreamAgentWithDefaultTransport() {
List<Event> streamResults =
agent.stream(Msg.builder().textContent("test").build()).collectList().block();
assertNotNull(streamResults);
assertEquals(2, streamResults.size());
assertFalse(streamResults.get(0).isLast());
assertTrue(streamResults.get(1).isLast());
assertEquals(3, streamResults.size());
assertFalse(streamResults.get(0).isLast()); // ReasoningChunkEvent
assertTrue(streamResults.get(1).isLast()); // PostReasoningEvent
assertTrue(streamResults.get(2).isLast()); // AGENT_RESULT
}

@Test
Expand Down Expand Up @@ -428,6 +433,113 @@ void testCallAgentWithDefaultTransportByObserve() {
assertEquals(3, agent.getMemory().getMessages().size());
}

@Test
@DisplayName("Should trigger Pre, Chunk, Post reasoning events")
void testAgentLifecycleHooksTriggeredCorrectly() {
AtomicInteger preCount = new AtomicInteger(0);
AtomicInteger chunkCount = new AtomicInteger(0);
AtomicInteger postCount = new AtomicInteger(0);

Hook lifecycleMonitorHook =
new Hook() {
@Override
public <T extends HookEvent> Mono<T> onEvent(T event) {
if (event instanceof PreReasoningEvent) {
preCount.incrementAndGet();
} else if (event instanceof ReasoningChunkEvent) {
chunkCount.incrementAndGet();
} else if (event instanceof PostReasoningEvent) {
postCount.incrementAndGet();
}
return Mono.just(event);
}

@Override
public int priority() {
return 1;
}
};

A2aAgent agent =
A2aAgent.builder()
.name("test-lifecycle-agent")
.agentCard(agentCard)
.hook(new ReplaceA2aClientHook())
.hook(lifecycleMonitorHook)
.build();

Answer<Void> mockTaskResponse =
invocation -> {
@SuppressWarnings("unchecked")
List<BiConsumer<ClientEvent, AgentCard>> a2aEventConsumer =
invocation.getArgument(1, List.class);

// Task creation
Task initialTask =
new Task.Builder()
.id("t1")
.contextId("c1")
.status(new TaskStatus(TaskState.WORKING))
.build();
a2aEventConsumer.forEach(c -> c.accept(new TaskEvent(initialTask), agentCard));

// Stream output a piece of text (Artifact Update)
TaskArtifactUpdateEvent chunkEvent =
new TaskArtifactUpdateEvent.Builder()
.taskId("t1")
.contextId("c1")
.artifact(
new Artifact.Builder()
.artifactId("a1")
.name("mockArtifact")
.parts(new TextPart("Hello A2A"))
.build())
.build();
Task workingTask =
new Task.Builder()
.id("t1")
.contextId("c1")
.status(new TaskStatus(TaskState.WORKING))
.artifacts(List.of(chunkEvent.getArtifact()))
.build();
a2aEventConsumer.forEach(
c -> c.accept(new TaskUpdateEvent(workingTask, chunkEvent), agentCard));

// Task complete (Status Update - COMPLETED)
Task completedTask =
new Task.Builder()
.id("t1")
.contextId("c1")
.status(new TaskStatus(TaskState.COMPLETED))
.artifacts(List.of(chunkEvent.getArtifact()))
.build();
TaskStatusUpdateEvent completeEvent =
new TaskStatusUpdateEvent(
"t1",
new TaskStatus(TaskState.COMPLETED),
"c1",
true,
Map.of());
a2aEventConsumer.forEach(
c ->
c.accept(
new TaskUpdateEvent(completedTask, completeEvent),
agentCard));

return null;
};

doAnswer(mockTaskResponse)
.when(a2aClient)
.sendMessage(any(Message.class), anyList(), any());

agent.stream(Msg.builder().textContent("测试触发").build()).collectList().block();

assertEquals(1, preCount.get());
assertEquals(1, chunkCount.get());
assertEquals(1, postCount.get());
}

private Answer<Void> mockSuccessMessage() {
return invocationOnMock -> {
@SuppressWarnings("unchecked")
Expand Down
Loading