From c18bbd229cf655a44bad705301a22383c45adcd9 Mon Sep 17 00:00:00 2001 From: VelikovPetar Date: Tue, 10 Mar 2026 12:07:54 +0100 Subject: [PATCH] Fix scroll jump when returning to a channel after WS reconnect Co-Authored-By: Claude --- .../channel/internal/ChannelLogicImpl.kt | 46 +++++++++- .../channel/internal/ChannelStateImpl.kt | 20 ++++ .../channel/internal/ChannelLogicImplTest.kt | 91 +++++++++++++++++++ .../internal/ChannelStateImplMessagesTest.kt | 50 ++++++++++ 4 files changed, 203 insertions(+), 4 deletions(-) diff --git a/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/plugin/logic/channel/internal/ChannelLogicImpl.kt b/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/plugin/logic/channel/internal/ChannelLogicImpl.kt index d3340d5c02d..8ded7564ae0 100644 --- a/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/plugin/logic/channel/internal/ChannelLogicImpl.kt +++ b/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/plugin/logic/channel/internal/ChannelLogicImpl.kt @@ -24,7 +24,9 @@ import io.getstream.chat.android.client.channel.ChannelMessagesUpdateLogic import io.getstream.chat.android.client.errors.isPermanent import io.getstream.chat.android.client.events.ChatEvent import io.getstream.chat.android.client.extensions.cidToTypeAndId +import io.getstream.chat.android.client.extensions.getCreatedAtOrDefault import io.getstream.chat.android.client.extensions.getCreatedAtOrNull +import io.getstream.chat.android.client.extensions.internal.NEVER import io.getstream.chat.android.client.internal.state.model.querychannels.pagination.internal.QueryChannelPaginationRequest import io.getstream.chat.android.client.internal.state.model.querychannels.pagination.internal.toAnyChannelPaginationRequest import io.getstream.chat.android.client.internal.state.plugin.state.channel.internal.ChannelStateImpl @@ -91,8 +93,8 @@ internal class ChannelLogicImpl( updateDataForChannel( channel = channel, messageLimit = query.messagesLimit(), + shouldRefreshMessages = true, // Note: The following arguments are NOT used. But they are kept for backwards compatibility. - shouldRefreshMessages = query.shouldRefresh, scrollUpdate = false, isNotificationUpdate = query.isNotificationUpdate, isChannelsStateUpdate = true, @@ -302,13 +304,39 @@ internal class ChannelLogicImpl( state.setChannelConfig(channel.config) // Set pending messages state.setPendingMessages(channel.pendingMessages.map(PendingMessage::message)) - // Reset messages (ensure they are sorted - when coming from DB) + // Update messages based on the relationship between the incoming page and existing state. if (messageLimit > 0) { val sortedMessages = withContext(Dispatchers.Default) { channel.messages.sortedBy { it.getCreatedAtOrNull() } } - state.setMessages(sortedMessages) - state.setEndOfOlderMessages(channel.messages.size < messageLimit) + val currentMessages = state.messages.value + when { + shouldRefreshMessages || currentMessages.isEmpty() -> { + // Initial load (DB seed or first fetch) or explicit refresh — full replace + state.setMessages(sortedMessages) + state.setEndOfOlderMessages(channel.messages.size < messageLimit) + } + state.insideSearch.value -> { + // User's window was already trimmed away from the latest (insideSearch set by + // trimNewestMessages, or a prior jump-to-message). Stay at current position; + // refresh the "jump to latest" cache with the server's current latest page. + state.upsertCachedLatestMessages(sortedMessages) + } + hasGap(currentMessages, sortedMessages) -> { + // Incoming page is newer than the current window with no overlap. Inserting the + // incoming messages would create a fragmented list. Instead, treat the user's + // position as a mid-page: store the incoming as the "latest" cache and signal the UI. + state.upsertCachedLatestMessages(sortedMessages) + state.setInsideSearch(true) + state.setEndOfNewerMessages(false) + } + else -> { + // Incoming messages are contiguous with (or overlap) the current window. + // Upsert preserves the user's scroll position while adding/updating messages. + state.upsertMessages(sortedMessages) + state.setEndOfOlderMessages(channel.messages.size < messageLimit) + } + } } // Add pinned messages state.addPinnedMessages(channel.pinnedMessages) @@ -428,4 +456,14 @@ internal class ChannelLogicImpl( // Enrich the channel with messages return channel.copy(messages = messages) } + + private fun hasGap(currentMessages: List, incomingMessages: List): Boolean { + val currentNewest = currentMessages.maxByOrNull { it.getCreatedAtOrDefault(NEVER) } + val incomingOldest = incomingMessages.firstOrNull() + return currentMessages.isNotEmpty() && + currentNewest != null && + incomingOldest != null && + currentMessages.none { it.id == incomingOldest.id } && + incomingOldest.getCreatedAtOrDefault(NEVER).after(currentNewest.getCreatedAtOrDefault(NEVER)) + } } diff --git a/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/plugin/state/channel/internal/ChannelStateImpl.kt b/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/plugin/state/channel/internal/ChannelStateImpl.kt index 25cc4bca93a..06eadae42d6 100644 --- a/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/plugin/state/channel/internal/ChannelStateImpl.kt +++ b/stream-chat-android-client/src/main/java/io/getstream/chat/android/client/internal/state/plugin/state/channel/internal/ChannelStateImpl.kt @@ -1411,6 +1411,26 @@ internal class ChannelStateImpl( _cachedLatestMessages.value = emptyList() } + /** + * Merges [messages] into the cached latest messages, replacing any existing entry + * with the same id and capping the list at [CACHED_LATEST_MESSAGES_LIMIT]. + * + * Called during reconnection to refresh the "jump to latest" cache with the server's + * current latest page without disturbing the user's active scroll position. + */ + fun upsertCachedLatestMessages(messages: List) { + if (messages.isEmpty()) return + val messagesToUpsert = messages.filterNot { shouldIgnoreUpsertion(it) } + if (messagesToUpsert.isEmpty()) return + _cachedLatestMessages.update { current -> + current.mergeSorted( + other = messagesToUpsert, + idSelector = Message::id, + comparator = MESSAGE_COMPARATOR, + ).takeLast(CACHED_LATEST_MESSAGES_LIMIT) + } + } + // endregion // region Destroy diff --git a/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/plugin/logic/channel/internal/ChannelLogicImplTest.kt b/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/plugin/logic/channel/internal/ChannelLogicImplTest.kt index 3951b1cc1a2..1dcf83180b6 100644 --- a/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/plugin/logic/channel/internal/ChannelLogicImplTest.kt +++ b/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/plugin/logic/channel/internal/ChannelLogicImplTest.kt @@ -1384,6 +1384,97 @@ internal class ChannelLogicImplTest { // Then verify(stateImpl).setPendingMessages(listOf(pendingMsg)) } + + @Test + fun `should upsert messages when state has messages and incoming are contiguous`() = runTest { + val existingMsg = randomMessage(id = "existing", createdAt = Date(1000L), createdLocallyAt = null) + whenever(stateImpl.messages).thenReturn(MutableStateFlow(listOf(existingMsg))) + val incomingMsg = randomMessage(id = "new", createdAt = Date(500L), createdLocallyAt = null) + val channel = randomChannel( + id = "123", + type = "messaging", + messages = listOf(incomingMsg), + members = emptyList(), + watchers = emptyList(), + read = emptyList(), + memberCount = 0, + watcherCount = 0, + ) + sut.updateDataForChannel(channel = channel, messageLimit = 30) + verify(stateImpl).upsertMessages(listOf(incomingMsg)) + verify(stateImpl, never()).setMessages(any()) + verify(stateImpl, never()).upsertCachedLatestMessages(any()) + verify(stateImpl, never()).setEndOfNewerMessages(any()) + } + + @Test + fun `should cache incoming and signal newer messages when gap is detected`() = runTest { + val existingMsg = randomMessage(id = "old", createdAt = Date(1000L), createdLocallyAt = null) + whenever(stateImpl.messages).thenReturn(MutableStateFlow(listOf(existingMsg))) + val incomingMsg = randomMessage(id = "new", createdAt = Date(5000L), createdLocallyAt = null) + val channel = randomChannel( + id = "123", + type = "messaging", + messages = listOf(incomingMsg), + members = emptyList(), + watchers = emptyList(), + read = emptyList(), + memberCount = 0, + watcherCount = 0, + ) + sut.updateDataForChannel(channel = channel, messageLimit = 30) + verify(stateImpl).upsertCachedLatestMessages(listOf(incomingMsg)) + verify(stateImpl).setInsideSearch(true) + verify(stateImpl).setEndOfNewerMessages(false) + verify(stateImpl, never()).setMessages(any()) + verify(stateImpl, never()).upsertMessages(any()) + verify(stateImpl, never()).setEndOfOlderMessages(any()) + } + + @Test + fun `should refresh cached latest messages when already inside search`() = runTest { + val existingMsg = randomMessage(id = "mid", createdAt = Date(1000L), createdLocallyAt = null) + whenever(stateImpl.messages).thenReturn(MutableStateFlow(listOf(existingMsg))) + whenever(stateImpl.insideSearch).thenReturn(MutableStateFlow(true)) + val incomingMsg = randomMessage(id = "latest", createdAt = Date(5000L), createdLocallyAt = null) + val channel = randomChannel( + id = "123", + type = "messaging", + messages = listOf(incomingMsg), + members = emptyList(), + watchers = emptyList(), + read = emptyList(), + memberCount = 0, + watcherCount = 0, + ) + sut.updateDataForChannel(channel = channel, messageLimit = 30) + verify(stateImpl).upsertCachedLatestMessages(listOf(incomingMsg)) + verify(stateImpl, never()).setMessages(any()) + verify(stateImpl, never()).upsertMessages(any()) + verify(stateImpl, never()).setInsideSearch(any()) + verify(stateImpl, never()).setEndOfNewerMessages(any()) + } + + @Test + fun `should replace messages when shouldRefreshMessages is true regardless of existing state`() = runTest { + val existingMsg = randomMessage(id = "old", createdAt = Date(1000L), createdLocallyAt = null) + whenever(stateImpl.messages).thenReturn(MutableStateFlow(listOf(existingMsg))) + val incomingMsg = randomMessage(id = "new", createdAt = Date(5000L), createdLocallyAt = null) + val channel = randomChannel( + id = "123", + type = "messaging", + messages = listOf(incomingMsg), + members = emptyList(), + watchers = emptyList(), + read = emptyList(), + memberCount = 0, + watcherCount = 0, + ) + sut.updateDataForChannel(channel = channel, messageLimit = 30, shouldRefreshMessages = true) + verify(stateImpl).setMessages(listOf(incomingMsg)) + verify(stateImpl, never()).upsertMessages(any()) + verify(stateImpl, never()).upsertCachedLatestMessages(any()) + } } // endregion diff --git a/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/plugin/state/channel/internal/ChannelStateImplMessagesTest.kt b/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/plugin/state/channel/internal/ChannelStateImplMessagesTest.kt index 93304d094d7..20c354372e5 100644 --- a/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/plugin/state/channel/internal/ChannelStateImplMessagesTest.kt +++ b/stream-chat-android-client/src/test/java/io/getstream/chat/android/client/internal/state/plugin/state/channel/internal/ChannelStateImplMessagesTest.kt @@ -601,6 +601,56 @@ internal class ChannelStateImplMessagesTest { } } + @Nested + inner class UpsertCachedLatestMessages { + + @Test + fun `upsertCachedLatestMessages with empty list should not change cache`() = runTest { + // given + val messages = createMessages(3) + channelState.setMessages(messages) + channelState.cacheLatestMessages() + channelState.setMessages(emptyList()) + val before = channelState.toChannel().cachedLatestMessages + // when + channelState.upsertCachedLatestMessages(emptyList()) + // then + assertEquals(before, channelState.toChannel().cachedLatestMessages) + } + + @Test + fun `upsertCachedLatestMessages with all filtered messages should not change cache`() = runTest { + // given + val regularMsg = createMessage(1, timestamp = 1000) + channelState.setMessages(listOf(regularMsg)) + channelState.cacheLatestMessages() + channelState.setMessages(emptyList()) + val before = channelState.toChannel().cachedLatestMessages + // when — thread reply not shown in channel is always filtered out + val threadReply = createMessage(2, parentId = "parent1", showInChannel = false) + channelState.upsertCachedLatestMessages(listOf(threadReply)) + // then + assertEquals(before, channelState.toChannel().cachedLatestMessages) + } + + @Test + fun `upsertCachedLatestMessages should merge incoming messages into the cache`() = runTest { + // given + val msg1 = createMessage(1, timestamp = 1000) + val msg5 = createMessage(5, timestamp = 5000) + channelState.setMessages(listOf(msg1, msg5)) + channelState.cacheLatestMessages() + channelState.setMessages(emptyList()) + // when + val msg3 = createMessage(3, timestamp = 3000) + channelState.upsertCachedLatestMessages(listOf(msg3)) + // then + val cachedMessages = channelState.toChannel().cachedLatestMessages + assertEquals(3, cachedMessages.size) + assertEquals(listOf("message_1", "message_3", "message_5"), cachedMessages.map { it.id }) + } + } + @Nested inner class GetMessageById {