diff --git a/vocata-server/src/test/java/com/vocata/ai/websocket/AiChatWebSocketHandlerTest.java b/vocata-server/src/test/java/com/vocata/ai/websocket/AiChatWebSocketHandlerTest.java index 6437a2f..4165897 100644 --- a/vocata-server/src/test/java/com/vocata/ai/websocket/AiChatWebSocketHandlerTest.java +++ b/vocata-server/src/test/java/com/vocata/ai/websocket/AiChatWebSocketHandlerTest.java @@ -1,10 +1,11 @@ package com.vocata.ai.websocket; -import com.vocata.ai.service.AiStreamingService; -import com.vocata.conversation.service.ConversationService; +import com.vocata.ai.pipeline.PipelineEvent; +import com.vocata.ai.pipeline.StreamingPipelineOrchestrator; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.test.util.ReflectionTestUtils; +import org.springframework.web.socket.BinaryMessage; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketMessage; @@ -12,7 +13,9 @@ import reactor.core.publisher.Flux; import java.io.IOException; +import java.lang.reflect.Constructor; import java.net.URI; +import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.Map; import java.util.UUID; @@ -20,10 +23,12 @@ import java.util.concurrent.atomic.AtomicInteger; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -32,112 +37,106 @@ class AiChatWebSocketHandlerTest { - private final AiStreamingService aiStreamingService = mock(AiStreamingService.class); - private final ConversationService conversationService = mock(ConversationService.class); + private static final String SESSION_ID = "session-1"; + private static final String USER_ID = "42"; private AiChatWebSocketHandler handler; @BeforeEach void setUp() { handler = new AiChatWebSocketHandler(); - ReflectionTestUtils.setField(handler, "aiStreamingService", aiStreamingService); - ReflectionTestUtils.setField(handler, "conversationService", conversationService); } @Test void audioStartBeginsVoiceProcessingImmediatelyAndAudioEndOnlyCompletesStream() throws Exception { WebSocketSession session = mockVoiceSession(); - AtomicInteger subscriptions = new AtomicInteger(); + StreamingPipelineOrchestrator orchestrator = mock(StreamingPipelineOrchestrator.class); + AtomicInteger voiceMessageCalls = new AtomicInteger(); AtomicBoolean completed = new AtomicBoolean(); - when(aiStreamingService.processVoiceMessage(eq(conversationUuid(session)), eq("42"), any())) + when(orchestrator.processVoiceMessage(eq(conversationUuid(session)), eq(USER_ID), any())) .thenAnswer(invocation -> { Flux audioStream = invocation.getArgument(2); + voiceMessageCalls.incrementAndGet(); return audioStream - .doOnSubscribe(ignored -> subscriptions.incrementAndGet()) .doOnComplete(() -> completed.set(true)) - .thenMany(Flux.empty()); + .thenMany(Flux.just(new PipelineEvent.Complete())); }); + installSessionState(session, orchestrator); + handler.handleMessage(session, new TextMessage("{\"type\":\"audio_start\"}")); - assertEquals(1, subscriptions.get()); - verify(aiStreamingService, times(1)) - .processVoiceMessage(eq(conversationUuid(session)), eq("42"), any()); + assertEquals(1, voiceMessageCalls.get()); + verify(orchestrator, times(1)) + .processVoiceMessage(eq(conversationUuid(session)), eq(USER_ID), any()); handler.handleMessage(session, new TextMessage("{\"type\":\"audio_end\"}")); assertTrue(completed.get()); - verify(aiStreamingService, times(1)) - .processVoiceMessage(eq(conversationUuid(session)), eq("42"), any()); - assertTrue(voiceSessions().isEmpty()); + assertTrue(sessions().containsKey(SESSION_ID)); + assertSessionAudioClosed(); } @Test void audioCancelDisposesActiveVoiceSession() throws Exception { WebSocketSession session = mockVoiceSession(); - AtomicBoolean cancelled = new AtomicBoolean(); + StreamingPipelineOrchestrator orchestrator = mock(StreamingPipelineOrchestrator.class); - when(aiStreamingService.processVoiceMessage(eq(conversationUuid(session)), eq("42"), any())) - .thenAnswer(invocation -> { - Flux audioStream = invocation.getArgument(2); - return audioStream.thenMany(Flux.>never() - .doOnCancel(() -> cancelled.set(true))); - }); + when(orchestrator.processVoiceMessage(eq(conversationUuid(session)), eq(USER_ID), any())) + .thenReturn(Flux.never()); + + installSessionState(session, orchestrator); handler.handleMessage(session, new TextMessage("{\"type\":\"audio_start\"}")); handler.handleMessage(session, new TextMessage("{\"type\":\"audio_cancel\"}")); - assertTrue(cancelled.get()); - assertTrue(voiceSessions().isEmpty()); + assertTrue(sessions().containsKey(SESSION_ID)); + assertSessionAudioClosed(); } @Test void afterConnectionClosedDisposesActiveVoiceSession() throws Exception { WebSocketSession session = mockVoiceSession(); - AtomicBoolean cancelled = new AtomicBoolean(); + StreamingPipelineOrchestrator orchestrator = mock(StreamingPipelineOrchestrator.class); - when(aiStreamingService.processVoiceMessage(eq(conversationUuid(session)), eq("42"), any())) - .thenAnswer(invocation -> { - Flux audioStream = invocation.getArgument(2); - return audioStream.thenMany(Flux.>never() - .doOnCancel(() -> cancelled.set(true))); - }); + when(orchestrator.processVoiceMessage(eq(conversationUuid(session)), eq(USER_ID), any())) + .thenReturn(Flux.never()); + + installSessionState(session, orchestrator); handler.handleMessage(session, new TextMessage("{\"type\":\"audio_start\"}")); handler.afterConnectionClosed(session, CloseStatus.NORMAL); - assertTrue(cancelled.get()); - assertTrue(voiceSessions().isEmpty()); + verify(orchestrator, times(1)).dispose(); + assertFalse(sessions().containsKey(SESSION_ID)); } @Test void transportErrorDisposesActiveVoiceSession() throws Exception { WebSocketSession session = mockVoiceSession(); - AtomicBoolean cancelled = new AtomicBoolean(); + StreamingPipelineOrchestrator orchestrator = mock(StreamingPipelineOrchestrator.class); - when(aiStreamingService.processVoiceMessage(eq(conversationUuid(session)), eq("42"), any())) - .thenAnswer(invocation -> { - Flux audioStream = invocation.getArgument(2); - return audioStream.thenMany(Flux.>never() - .doOnCancel(() -> cancelled.set(true))); - }); + when(orchestrator.processVoiceMessage(eq(conversationUuid(session)), eq(USER_ID), any())) + .thenReturn(Flux.never()); + + installSessionState(session, orchestrator); handler.handleMessage(session, new TextMessage("{\"type\":\"audio_start\"}")); handler.handleTransportError(session, new IOException("boom")); - assertTrue(cancelled.get()); - assertTrue(voiceSessions().isEmpty()); + verify(orchestrator, times(1)).dispose(); + assertFalse(sessions().containsKey(SESSION_ID)); } @Test void sttResultIsForwardedToClientWithNormalizedFields() throws Exception { WebSocketSession session = mockVoiceSession(); + StreamingPipelineOrchestrator orchestrator = mock(StreamingPipelineOrchestrator.class); AtomicInteger textMessages = new AtomicInteger(); AtomicBoolean normalizedSttSeen = new AtomicBoolean(); - doNothing().when(session).sendMessage(any(WebSocketMessage.class)); - org.mockito.Mockito.doAnswer(invocation -> { + doAnswer(invocation -> { WebSocketMessage outbound = invocation.getArgument(0); if (outbound instanceof TextMessage textMessage) { textMessages.incrementAndGet(); @@ -152,15 +151,13 @@ void sttResultIsForwardedToClientWithNormalizedFields() throws Exception { return null; }).when(session).sendMessage(any(WebSocketMessage.class)); - when(aiStreamingService.processVoiceMessage(eq(conversationUuid(session)), eq("42"), any())) - .thenReturn(Flux.just(Map.of( - "type", "stt_result", - "payload", Map.of( - "text", "你好", - "confidence", 0.75, - "is_final", false - ) - ))); + when(orchestrator.processVoiceMessage(eq(conversationUuid(session)), eq(USER_ID), any())) + .thenReturn(Flux.just( + new PipelineEvent.SttResult("你好", false, 0.75), + new PipelineEvent.Complete() + )); + + installSessionState(session, orchestrator); handler.handleMessage(session, new TextMessage("{\"type\":\"audio_start\"}")); handler.handleMessage(session, new TextMessage("{\"type\":\"audio_end\"}")); @@ -173,11 +170,11 @@ void sttResultIsForwardedToClientWithNormalizedFields() throws Exception { @Test void duplicateAudioStartReturnsExplicitSessionError() throws Exception { WebSocketSession session = mockVoiceSession(); - AtomicInteger subscriptions = new AtomicInteger(); + StreamingPipelineOrchestrator orchestrator = mock(StreamingPipelineOrchestrator.class); + AtomicInteger voiceMessageCalls = new AtomicInteger(); AtomicBoolean duplicateStartErrorSeen = new AtomicBoolean(); - doNothing().when(session).sendMessage(any(WebSocketMessage.class)); - org.mockito.Mockito.doAnswer(invocation -> { + doAnswer(invocation -> { WebSocketMessage outbound = invocation.getArgument(0); if (outbound instanceof TextMessage textMessage) { String payload = textMessage.getPayload(); @@ -189,28 +186,50 @@ void duplicateAudioStartReturnsExplicitSessionError() throws Exception { return null; }).when(session).sendMessage(any(WebSocketMessage.class)); - when(aiStreamingService.processVoiceMessage(eq(conversationUuid(session)), eq("42"), any())) + when(orchestrator.processVoiceMessage(eq(conversationUuid(session)), eq(USER_ID), any())) .thenAnswer(invocation -> { - Flux audioStream = invocation.getArgument(2); - return audioStream - .doOnSubscribe(ignored -> subscriptions.incrementAndGet()) - .thenMany(Flux.>never()); + voiceMessageCalls.incrementAndGet(); + return Flux.never(); }); + installSessionState(session, orchestrator); + handler.handleMessage(session, new TextMessage("{\"type\":\"audio_start\"}")); handler.handleMessage(session, new TextMessage("{\"type\":\"audio_start\"}")); - assertEquals(1, subscriptions.get()); + assertEquals(1, voiceMessageCalls.get()); assertTrue(duplicateStartErrorSeen.get()); } + @Test + void binaryFrameIsForwardedIntoActiveAudioSink() throws Exception { + WebSocketSession session = mockVoiceSession(); + StreamingPipelineOrchestrator orchestrator = mock(StreamingPipelineOrchestrator.class); + AtomicInteger audioChunks = new AtomicInteger(); + + when(orchestrator.processVoiceMessage(eq(conversationUuid(session)), eq(USER_ID), any())) + .thenAnswer(invocation -> { + Flux audioStream = invocation.getArgument(2); + return audioStream + .doOnNext(chunk -> audioChunks.incrementAndGet()) + .thenMany(Flux.never()); + }); + + installSessionState(session, orchestrator); + + handler.handleMessage(session, new TextMessage("{\"type\":\"audio_start\"}")); + handler.handleMessage(session, new BinaryMessage("abc".getBytes(StandardCharsets.UTF_8))); + + assertEquals(1, audioChunks.get()); + } + private WebSocketSession mockVoiceSession() throws IOException { WebSocketSession session = mock(WebSocketSession.class); Map attributes = new HashMap<>(); - attributes.put("authenticatedUserId", "42"); + attributes.put("authenticatedUserId", USER_ID); UUID conversationUuid = UUID.randomUUID(); - when(session.getId()).thenReturn("session-1"); + when(session.getId()).thenReturn(SESSION_ID); when(session.getUri()).thenReturn(URI.create("ws://localhost/ws/chat/" + conversationUuid)); when(session.getAttributes()).thenReturn(attributes); when(session.isOpen()).thenReturn(true); @@ -220,9 +239,28 @@ private WebSocketSession mockVoiceSession() throws IOException { } @SuppressWarnings("unchecked") - private Map voiceSessions() { - Object voiceSessions = ReflectionTestUtils.getField(handler, "voiceSessions"); - return voiceSessions == null ? Map.of() : (Map) voiceSessions; + private Map sessions() { + Object sessions = ReflectionTestUtils.getField(handler, "sessions"); + return sessions == null ? Map.of() : (Map) sessions; + } + + private void installSessionState(WebSocketSession session, StreamingPipelineOrchestrator orchestrator) throws Exception { + Class sessionStateClass = Class.forName("com.vocata.ai.websocket.AiChatWebSocketHandler$SessionState"); + Constructor constructor = sessionStateClass.getDeclaredConstructor(StreamingPipelineOrchestrator.class); + constructor.setAccessible(true); + Object sessionState = constructor.newInstance(orchestrator); + + @SuppressWarnings("unchecked") + Map sessions = (Map) ReflectionTestUtils.getField(handler, "sessions"); + sessions.put(session.getId(), sessionState); + } + + private void assertSessionAudioClosed() { + Object sessionState = sessions().get(SESSION_ID); + Object audioSink = ReflectionTestUtils.getField(sessionState, "audioSink"); + Object pipelineSubscription = ReflectionTestUtils.getField(sessionState, "pipelineSubscription"); + assertEquals(null, audioSink); + assertEquals(null, pipelineSubscription); } private String conversationUuid(WebSocketSession session) {