From 97be429e43c57ca29b24e462e2fca2ac6770c9eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20Spie=C3=9F?= Date: Wed, 14 Jan 2026 19:59:55 +0100 Subject: [PATCH] Add synchronization to DaveSessionManager --- .../jdave/manager/DaveSessionManager.java | 94 +++++++++++++++---- 1 file changed, 78 insertions(+), 16 deletions(-) diff --git a/api/src/main/java/club/minnced/discord/jdave/manager/DaveSessionManager.java b/api/src/main/java/club/minnced/discord/jdave/manager/DaveSessionManager.java index f0df2ea..9b08f78 100644 --- a/api/src/main/java/club/minnced/discord/jdave/manager/DaveSessionManager.java +++ b/api/src/main/java/club/minnced/discord/jdave/manager/DaveSessionManager.java @@ -30,6 +30,7 @@ public class DaveSessionManager implements AutoCloseable { private final Map preparedTransitions = new ConcurrentHashMap<>(); private int currentProtocolVersion = DISABLED_PROTOCOL_VERSION; + private volatile boolean shutdown = false; private DaveSessionManager(long selfUserId, long channelId, @NonNull DaveSessionManagerCallbacks callbacks) { this(selfUserId, channelId, callbacks, DaveSessionImpl.create(null)); @@ -63,7 +64,8 @@ public static DaveSessionManager create( } @Override - public void close() { + public synchronized void close() { + shutdown = true; encryptor.close(); decryptors.values().forEach(DaveDecryptor::close); decryptors.clear(); @@ -74,15 +76,27 @@ public int getMaxProtocolVersion() { return LibDave.getMaxSupportedProtocolVersion(); } - public void assignSsrcToCodec(@NonNull DaveCodec codec, int ssrc) { + public synchronized void assignSsrcToCodec(@NonNull DaveCodec codec, int ssrc) { + if (shutdown) { + return; + } + encryptor.assignSsrcToCodec(codec, ssrc); } - public int getMaxEncryptedFrameSize(@NonNull DaveMediaType type, int frameSize) { + public synchronized int getMaxEncryptedFrameSize(@NonNull DaveMediaType type, int frameSize) { + if (shutdown) { + return frameSize; + } + return (int) encryptor.getMaxCiphertextByteSize(type, frameSize); } - public int getMaxDecryptedFrameSize(@NonNull DaveMediaType type, long userId, int frameSize) { + public synchronized int getMaxDecryptedFrameSize(@NonNull DaveMediaType type, long userId, int frameSize) { + if (shutdown) { + return frameSize; + } + DaveDecryptor decryptor = this.decryptors.get(userId); if (decryptor == null) { return frameSize; @@ -92,15 +106,23 @@ public int getMaxDecryptedFrameSize(@NonNull DaveMediaType type, long userId, in } @NonNull - public DaveEncryptResultType encrypt( + public synchronized DaveEncryptResultType encrypt( @NonNull DaveMediaType type, int ssrc, @NonNull ByteBuffer audio, @NonNull ByteBuffer encrypted) { + if (shutdown) { + return DaveEncryptResultType.FAILURE; + } + DaveEncryptor.DaveEncryptorResult result = encryptor.encrypt(type, ssrc, audio, encrypted); return result.type(); } @NonNull - public DaveDecryptResultType decrypt( + public synchronized DaveDecryptResultType decrypt( @NonNull DaveMediaType type, long userId, @NonNull ByteBuffer encrypted, @NonNull ByteBuffer decrypted) { + if (shutdown) { + return DaveDecryptResultType.FAILURE; + } + DaveDecryptor decryptor = decryptors.get(userId); if (decryptor != null) { @@ -111,13 +133,21 @@ public DaveDecryptResultType decrypt( } @SuppressWarnings("resource") - public void addUser(long userId) { + public synchronized void addUser(long userId) { + if (shutdown) { + return; + } + log.debug("Adding user {}", userId); DaveDecryptor decryptor = decryptors.computeIfAbsent(userId, id -> DaveDecryptor.create(id, session)); decryptor.prepareTransition(currentProtocolVersion); } - public void removeUser(long userId) { + public synchronized void removeUser(long userId) { + if (shutdown) { + return; + } + log.debug("Removing user {}", userId); DaveDecryptor decryptor = decryptors.remove(userId); if (decryptor != null) { @@ -125,12 +155,20 @@ public void removeUser(long userId) { } } - public void onSelectProtocolAck(int protocolVersion) { + public synchronized void onSelectProtocolAck(int protocolVersion) { + if (shutdown) { + return; + } + log.debug("Handle select protocol version {}", protocolVersion); handleDaveProtocolInit(protocolVersion); } - public void onDaveProtocolPrepareTransition(int transitionId, int protocolVersion) { + public synchronized void onDaveProtocolPrepareTransition(int transitionId, int protocolVersion) { + if (shutdown) { + return; + } + log.debug( "Handle dave protocol prepare transition transitionId={} protocolVersion={}", transitionId, @@ -139,27 +177,47 @@ public void onDaveProtocolPrepareTransition(int transitionId, int protocolVersio prepareProtocolTransition(transitionId, protocolVersion); } - public void onDaveProtocolExecuteTransition(int transitionId) { + public synchronized void onDaveProtocolExecuteTransition(int transitionId) { + if (shutdown) { + return; + } + log.debug("Handle dave protocol execute transition transitionId={}", transitionId); executeProtocolTransition(transitionId); } - public void onDaveProtocolPrepareEpoch(long epoch, int protocolVersion) { + public synchronized void onDaveProtocolPrepareEpoch(long epoch, int protocolVersion) { + if (shutdown) { + return; + } + log.debug("Handle dave protocol prepare epoch epoch={} protocolVersion={}", epoch, protocolVersion); handlePrepareEpoch(epoch, (short) protocolVersion); } - public void onDaveProtocolMLSExternalSenderPackage(@NonNull ByteBuffer externalSenderPackage) { + public synchronized void onDaveProtocolMLSExternalSenderPackage(@NonNull ByteBuffer externalSenderPackage) { + if (shutdown) { + return; + } + log.debug("Handling external sender package"); session.setExternalSender(externalSenderPackage); } - public void onMLSProposals(@NonNull ByteBuffer proposals) { + public synchronized void onMLSProposals(@NonNull ByteBuffer proposals) { + if (shutdown) { + return; + } + log.debug("Handling MLS proposals"); session.processProposals(proposals, getRecognizedUserIds(), callbacks::sendMLSCommitWelcome); } - public void onMLSPrepareCommitTransition(int transitionId, @NonNull ByteBuffer commit) { + public synchronized void onMLSPrepareCommitTransition(int transitionId, @NonNull ByteBuffer commit) { + if (shutdown) { + return; + } + log.debug("Handling MLS prepare commit transition transitionId={}", transitionId); DaveSessionImpl.CommitResult result = session.processCommit(commit); switch (result) { @@ -177,7 +235,11 @@ public void onMLSPrepareCommitTransition(int transitionId, @NonNull ByteBuffer c } } - public void onMLSWelcome(int transitionId, @NonNull ByteBuffer welcome) { + public synchronized void onMLSWelcome(int transitionId, @NonNull ByteBuffer welcome) { + if (shutdown) { + return; + } + log.debug("Handling MLS welcome transition transitionId={}", transitionId); boolean joinedGroup = session.processWelcome(welcome, getRecognizedUserIds());