Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,29 +1,34 @@
package com.vocata.ai.websocket;

import com.vocata.ai.service.AiStreamingService;
import com.vocata.conversation.service.ConversationService;
import com.vocata.ai.pipeline.PipelineEvent;
import com.vocata.ai.pipeline.StreamingPipelineOrchestrator;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import reactor.core.publisher.Flux;

import java.io.IOException;
import java.lang.reflect.Constructor;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
Expand All @@ -32,112 +37,106 @@

class AiChatWebSocketHandlerTest {

private final AiStreamingService aiStreamingService = mock(AiStreamingService.class);
private final ConversationService conversationService = mock(ConversationService.class);
private static final String SESSION_ID = "session-1";
private static final String USER_ID = "42";

private AiChatWebSocketHandler handler;

@BeforeEach
void setUp() {
handler = new AiChatWebSocketHandler();
ReflectionTestUtils.setField(handler, "aiStreamingService", aiStreamingService);
ReflectionTestUtils.setField(handler, "conversationService", conversationService);
}

@Test
void audioStartBeginsVoiceProcessingImmediatelyAndAudioEndOnlyCompletesStream() throws Exception {
WebSocketSession session = mockVoiceSession();
AtomicInteger subscriptions = new AtomicInteger();
StreamingPipelineOrchestrator orchestrator = mock(StreamingPipelineOrchestrator.class);
AtomicInteger voiceMessageCalls = new AtomicInteger();
AtomicBoolean completed = new AtomicBoolean();

when(aiStreamingService.processVoiceMessage(eq(conversationUuid(session)), eq("42"), any()))
when(orchestrator.processVoiceMessage(eq(conversationUuid(session)), eq(USER_ID), any()))
.thenAnswer(invocation -> {
Flux<byte[]> audioStream = invocation.getArgument(2);
voiceMessageCalls.incrementAndGet();
return audioStream
.doOnSubscribe(ignored -> subscriptions.incrementAndGet())
.doOnComplete(() -> completed.set(true))
.thenMany(Flux.empty());
.thenMany(Flux.just(new PipelineEvent.Complete()));
});

installSessionState(session, orchestrator);

handler.handleMessage(session, new TextMessage("{\"type\":\"audio_start\"}"));

assertEquals(1, subscriptions.get());
verify(aiStreamingService, times(1))
.processVoiceMessage(eq(conversationUuid(session)), eq("42"), any());
assertEquals(1, voiceMessageCalls.get());
verify(orchestrator, times(1))
.processVoiceMessage(eq(conversationUuid(session)), eq(USER_ID), any());

handler.handleMessage(session, new TextMessage("{\"type\":\"audio_end\"}"));

assertTrue(completed.get());
verify(aiStreamingService, times(1))
.processVoiceMessage(eq(conversationUuid(session)), eq("42"), any());
assertTrue(voiceSessions().isEmpty());
assertTrue(sessions().containsKey(SESSION_ID));
assertSessionAudioClosed();
}

@Test
void audioCancelDisposesActiveVoiceSession() throws Exception {
WebSocketSession session = mockVoiceSession();
AtomicBoolean cancelled = new AtomicBoolean();
StreamingPipelineOrchestrator orchestrator = mock(StreamingPipelineOrchestrator.class);

when(aiStreamingService.processVoiceMessage(eq(conversationUuid(session)), eq("42"), any()))
.thenAnswer(invocation -> {
Flux<byte[]> audioStream = invocation.getArgument(2);
return audioStream.thenMany(Flux.<Map<String, Object>>never()
.doOnCancel(() -> cancelled.set(true)));
});
when(orchestrator.processVoiceMessage(eq(conversationUuid(session)), eq(USER_ID), any()))
.thenReturn(Flux.never());

installSessionState(session, orchestrator);

handler.handleMessage(session, new TextMessage("{\"type\":\"audio_start\"}"));
handler.handleMessage(session, new TextMessage("{\"type\":\"audio_cancel\"}"));

assertTrue(cancelled.get());
assertTrue(voiceSessions().isEmpty());
assertTrue(sessions().containsKey(SESSION_ID));
assertSessionAudioClosed();
}

@Test
void afterConnectionClosedDisposesActiveVoiceSession() throws Exception {
WebSocketSession session = mockVoiceSession();
AtomicBoolean cancelled = new AtomicBoolean();
StreamingPipelineOrchestrator orchestrator = mock(StreamingPipelineOrchestrator.class);

when(aiStreamingService.processVoiceMessage(eq(conversationUuid(session)), eq("42"), any()))
.thenAnswer(invocation -> {
Flux<byte[]> audioStream = invocation.getArgument(2);
return audioStream.thenMany(Flux.<Map<String, Object>>never()
.doOnCancel(() -> cancelled.set(true)));
});
when(orchestrator.processVoiceMessage(eq(conversationUuid(session)), eq(USER_ID), any()))
.thenReturn(Flux.never());

installSessionState(session, orchestrator);

handler.handleMessage(session, new TextMessage("{\"type\":\"audio_start\"}"));
handler.afterConnectionClosed(session, CloseStatus.NORMAL);

assertTrue(cancelled.get());
assertTrue(voiceSessions().isEmpty());
verify(orchestrator, times(1)).dispose();
assertFalse(sessions().containsKey(SESSION_ID));
}

@Test
void transportErrorDisposesActiveVoiceSession() throws Exception {
WebSocketSession session = mockVoiceSession();
AtomicBoolean cancelled = new AtomicBoolean();
StreamingPipelineOrchestrator orchestrator = mock(StreamingPipelineOrchestrator.class);

when(aiStreamingService.processVoiceMessage(eq(conversationUuid(session)), eq("42"), any()))
.thenAnswer(invocation -> {
Flux<byte[]> audioStream = invocation.getArgument(2);
return audioStream.thenMany(Flux.<Map<String, Object>>never()
.doOnCancel(() -> cancelled.set(true)));
});
when(orchestrator.processVoiceMessage(eq(conversationUuid(session)), eq(USER_ID), any()))
.thenReturn(Flux.never());

installSessionState(session, orchestrator);

handler.handleMessage(session, new TextMessage("{\"type\":\"audio_start\"}"));
handler.handleTransportError(session, new IOException("boom"));

assertTrue(cancelled.get());
assertTrue(voiceSessions().isEmpty());
verify(orchestrator, times(1)).dispose();
assertFalse(sessions().containsKey(SESSION_ID));
}

@Test
void sttResultIsForwardedToClientWithNormalizedFields() throws Exception {
WebSocketSession session = mockVoiceSession();
StreamingPipelineOrchestrator orchestrator = mock(StreamingPipelineOrchestrator.class);
AtomicInteger textMessages = new AtomicInteger();
AtomicBoolean normalizedSttSeen = new AtomicBoolean();

doNothing().when(session).sendMessage(any(WebSocketMessage.class));
org.mockito.Mockito.doAnswer(invocation -> {
doAnswer(invocation -> {
WebSocketMessage<?> outbound = invocation.getArgument(0);
if (outbound instanceof TextMessage textMessage) {
textMessages.incrementAndGet();
Expand All @@ -152,15 +151,13 @@ void sttResultIsForwardedToClientWithNormalizedFields() throws Exception {
return null;
}).when(session).sendMessage(any(WebSocketMessage.class));

when(aiStreamingService.processVoiceMessage(eq(conversationUuid(session)), eq("42"), any()))
.thenReturn(Flux.just(Map.of(
"type", "stt_result",
"payload", Map.of(
"text", "你好",
"confidence", 0.75,
"is_final", false
)
)));
when(orchestrator.processVoiceMessage(eq(conversationUuid(session)), eq(USER_ID), any()))
.thenReturn(Flux.just(
new PipelineEvent.SttResult("你好", false, 0.75),
new PipelineEvent.Complete()
));

installSessionState(session, orchestrator);

handler.handleMessage(session, new TextMessage("{\"type\":\"audio_start\"}"));
handler.handleMessage(session, new TextMessage("{\"type\":\"audio_end\"}"));
Expand All @@ -173,11 +170,11 @@ void sttResultIsForwardedToClientWithNormalizedFields() throws Exception {
@Test
void duplicateAudioStartReturnsExplicitSessionError() throws Exception {
WebSocketSession session = mockVoiceSession();
AtomicInteger subscriptions = new AtomicInteger();
StreamingPipelineOrchestrator orchestrator = mock(StreamingPipelineOrchestrator.class);
AtomicInteger voiceMessageCalls = new AtomicInteger();
AtomicBoolean duplicateStartErrorSeen = new AtomicBoolean();

doNothing().when(session).sendMessage(any(WebSocketMessage.class));
org.mockito.Mockito.doAnswer(invocation -> {
doAnswer(invocation -> {
WebSocketMessage<?> outbound = invocation.getArgument(0);
if (outbound instanceof TextMessage textMessage) {
String payload = textMessage.getPayload();
Expand All @@ -189,28 +186,50 @@ void duplicateAudioStartReturnsExplicitSessionError() throws Exception {
return null;
}).when(session).sendMessage(any(WebSocketMessage.class));

when(aiStreamingService.processVoiceMessage(eq(conversationUuid(session)), eq("42"), any()))
when(orchestrator.processVoiceMessage(eq(conversationUuid(session)), eq(USER_ID), any()))
.thenAnswer(invocation -> {
Flux<byte[]> audioStream = invocation.getArgument(2);
return audioStream
.doOnSubscribe(ignored -> subscriptions.incrementAndGet())
.thenMany(Flux.<Map<String, Object>>never());
voiceMessageCalls.incrementAndGet();
return Flux.never();
});

installSessionState(session, orchestrator);

handler.handleMessage(session, new TextMessage("{\"type\":\"audio_start\"}"));
handler.handleMessage(session, new TextMessage("{\"type\":\"audio_start\"}"));

assertEquals(1, subscriptions.get());
assertEquals(1, voiceMessageCalls.get());
assertTrue(duplicateStartErrorSeen.get());
}

@Test
void binaryFrameIsForwardedIntoActiveAudioSink() throws Exception {
WebSocketSession session = mockVoiceSession();
StreamingPipelineOrchestrator orchestrator = mock(StreamingPipelineOrchestrator.class);
AtomicInteger audioChunks = new AtomicInteger();

when(orchestrator.processVoiceMessage(eq(conversationUuid(session)), eq(USER_ID), any()))
.thenAnswer(invocation -> {
Flux<byte[]> audioStream = invocation.getArgument(2);
return audioStream
.doOnNext(chunk -> audioChunks.incrementAndGet())
.thenMany(Flux.never());
});

installSessionState(session, orchestrator);

handler.handleMessage(session, new TextMessage("{\"type\":\"audio_start\"}"));
handler.handleMessage(session, new BinaryMessage("abc".getBytes(StandardCharsets.UTF_8)));

assertEquals(1, audioChunks.get());
}

private WebSocketSession mockVoiceSession() throws IOException {
WebSocketSession session = mock(WebSocketSession.class);
Map<String, Object> attributes = new HashMap<>();
attributes.put("authenticatedUserId", "42");
attributes.put("authenticatedUserId", USER_ID);

UUID conversationUuid = UUID.randomUUID();
when(session.getId()).thenReturn("session-1");
when(session.getId()).thenReturn(SESSION_ID);
when(session.getUri()).thenReturn(URI.create("ws://localhost/ws/chat/" + conversationUuid));
when(session.getAttributes()).thenReturn(attributes);
when(session.isOpen()).thenReturn(true);
Expand All @@ -220,9 +239,28 @@ private WebSocketSession mockVoiceSession() throws IOException {
}

@SuppressWarnings("unchecked")
private Map<String, Object> voiceSessions() {
Object voiceSessions = ReflectionTestUtils.getField(handler, "voiceSessions");
return voiceSessions == null ? Map.of() : (Map<String, Object>) voiceSessions;
private Map<String, Object> sessions() {
Object sessions = ReflectionTestUtils.getField(handler, "sessions");
return sessions == null ? Map.of() : (Map<String, Object>) sessions;
}

private void installSessionState(WebSocketSession session, StreamingPipelineOrchestrator orchestrator) throws Exception {
Class<?> sessionStateClass = Class.forName("com.vocata.ai.websocket.AiChatWebSocketHandler$SessionState");
Constructor<?> constructor = sessionStateClass.getDeclaredConstructor(StreamingPipelineOrchestrator.class);
constructor.setAccessible(true);
Object sessionState = constructor.newInstance(orchestrator);

Comment on lines +247 to +252
Copy link

Copilot AI Apr 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test setup currently constructs the private inner class AiChatWebSocketHandler$SessionState via reflection (Class.forName + setAccessible). This couples the tests to handler internals and makes them brittle under refactors. If feasible, consider adding a small package-private/@VisibleForTesting hook on AiChatWebSocketHandler to register a mocked StreamingPipelineOrchestrator for a session (or extracting SessionState into a package-private type) so tests can avoid reflective access.

Copilot uses AI. Check for mistakes.
@SuppressWarnings("unchecked")
Map<String, Object> sessions = (Map<String, Object>) ReflectionTestUtils.getField(handler, "sessions");
sessions.put(session.getId(), sessionState);
}

private void assertSessionAudioClosed() {
Object sessionState = sessions().get(SESSION_ID);
Object audioSink = ReflectionTestUtils.getField(sessionState, "audioSink");
Object pipelineSubscription = ReflectionTestUtils.getField(sessionState, "pipelineSubscription");
assertEquals(null, audioSink);
assertEquals(null, pipelineSubscription);
Comment on lines +258 to +263
Copy link

Copilot AI Apr 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use assertNull(...) (and optionally an assertion message) instead of assertEquals(null, ...) for clearer intent and better failure output when checking that audioSink / pipelineSubscription have been cleared.

Copilot uses AI. Check for mistakes.
}

private String conversationUuid(WebSocketSession session) {
Expand Down
Loading