Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public class DaveSessionManager implements AutoCloseable {
private final Map<Integer, Integer> preparedTransitions = new ConcurrentHashMap<>();

private int currentProtocolVersion = DISABLED_PROTOCOL_VERSION;
private volatile boolean shutdown = false;

private DaveSessionManager(long selfUserId, long channelId, @NonNull DaveSessionManagerCallbacks callbacks) {
this(selfUserId, channelId, callbacks, DaveSessionImpl.create(null));
Expand Down Expand Up @@ -63,7 +64,8 @@ public static DaveSessionManager create(
}

@Override
public void close() {
public synchronized void close() {
shutdown = true;
encryptor.close();
decryptors.values().forEach(DaveDecryptor::close);
decryptors.clear();
Expand All @@ -74,15 +76,27 @@ public int getMaxProtocolVersion() {
return LibDave.getMaxSupportedProtocolVersion();
}

public void assignSsrcToCodec(@NonNull DaveCodec codec, int ssrc) {
public synchronized void assignSsrcToCodec(@NonNull DaveCodec codec, int ssrc) {
if (shutdown) {
return;
}

encryptor.assignSsrcToCodec(codec, ssrc);
}

public int getMaxEncryptedFrameSize(@NonNull DaveMediaType type, int frameSize) {
public synchronized int getMaxEncryptedFrameSize(@NonNull DaveMediaType type, int frameSize) {
if (shutdown) {
return frameSize;
}

return (int) encryptor.getMaxCiphertextByteSize(type, frameSize);
}

public int getMaxDecryptedFrameSize(@NonNull DaveMediaType type, long userId, int frameSize) {
public synchronized int getMaxDecryptedFrameSize(@NonNull DaveMediaType type, long userId, int frameSize) {
if (shutdown) {
return frameSize;
}

DaveDecryptor decryptor = this.decryptors.get(userId);
if (decryptor == null) {
return frameSize;
Expand All @@ -92,15 +106,23 @@ public int getMaxDecryptedFrameSize(@NonNull DaveMediaType type, long userId, in
}

@NonNull
public DaveEncryptResultType encrypt(
public synchronized DaveEncryptResultType encrypt(
@NonNull DaveMediaType type, int ssrc, @NonNull ByteBuffer audio, @NonNull ByteBuffer encrypted) {
if (shutdown) {
return DaveEncryptResultType.FAILURE;
}

DaveEncryptor.DaveEncryptorResult result = encryptor.encrypt(type, ssrc, audio, encrypted);
return result.type();
}

@NonNull
public DaveDecryptResultType decrypt(
public synchronized DaveDecryptResultType decrypt(
@NonNull DaveMediaType type, long userId, @NonNull ByteBuffer encrypted, @NonNull ByteBuffer decrypted) {
if (shutdown) {
return DaveDecryptResultType.FAILURE;
}

DaveDecryptor decryptor = decryptors.get(userId);

if (decryptor != null) {
Expand All @@ -111,26 +133,42 @@ public DaveDecryptResultType decrypt(
}

@SuppressWarnings("resource")
public void addUser(long userId) {
public synchronized void addUser(long userId) {
if (shutdown) {
return;
}

log.debug("Adding user {}", userId);
DaveDecryptor decryptor = decryptors.computeIfAbsent(userId, id -> DaveDecryptor.create(id, session));
decryptor.prepareTransition(currentProtocolVersion);
}

public void removeUser(long userId) {
public synchronized void removeUser(long userId) {
if (shutdown) {
return;
}

log.debug("Removing user {}", userId);
DaveDecryptor decryptor = decryptors.remove(userId);
if (decryptor != null) {
decryptor.close();
}
}

public void onSelectProtocolAck(int protocolVersion) {
public synchronized void onSelectProtocolAck(int protocolVersion) {
if (shutdown) {
return;
}

log.debug("Handle select protocol version {}", protocolVersion);
handleDaveProtocolInit(protocolVersion);
}

public void onDaveProtocolPrepareTransition(int transitionId, int protocolVersion) {
public synchronized void onDaveProtocolPrepareTransition(int transitionId, int protocolVersion) {
if (shutdown) {
return;
}

log.debug(
"Handle dave protocol prepare transition transitionId={} protocolVersion={}",
transitionId,
Expand All @@ -139,27 +177,47 @@ public void onDaveProtocolPrepareTransition(int transitionId, int protocolVersio
prepareProtocolTransition(transitionId, protocolVersion);
}

public void onDaveProtocolExecuteTransition(int transitionId) {
public synchronized void onDaveProtocolExecuteTransition(int transitionId) {
if (shutdown) {
return;
}

log.debug("Handle dave protocol execute transition transitionId={}", transitionId);
executeProtocolTransition(transitionId);
}

public void onDaveProtocolPrepareEpoch(long epoch, int protocolVersion) {
public synchronized void onDaveProtocolPrepareEpoch(long epoch, int protocolVersion) {
if (shutdown) {
return;
}

log.debug("Handle dave protocol prepare epoch epoch={} protocolVersion={}", epoch, protocolVersion);
handlePrepareEpoch(epoch, (short) protocolVersion);
}

public void onDaveProtocolMLSExternalSenderPackage(@NonNull ByteBuffer externalSenderPackage) {
public synchronized void onDaveProtocolMLSExternalSenderPackage(@NonNull ByteBuffer externalSenderPackage) {
if (shutdown) {
return;
}

log.debug("Handling external sender package");
session.setExternalSender(externalSenderPackage);
}

public void onMLSProposals(@NonNull ByteBuffer proposals) {
public synchronized void onMLSProposals(@NonNull ByteBuffer proposals) {
if (shutdown) {
return;
}

log.debug("Handling MLS proposals");
session.processProposals(proposals, getRecognizedUserIds(), callbacks::sendMLSCommitWelcome);
}

public void onMLSPrepareCommitTransition(int transitionId, @NonNull ByteBuffer commit) {
public synchronized void onMLSPrepareCommitTransition(int transitionId, @NonNull ByteBuffer commit) {
if (shutdown) {
return;
}

log.debug("Handling MLS prepare commit transition transitionId={}", transitionId);
DaveSessionImpl.CommitResult result = session.processCommit(commit);
switch (result) {
Expand All @@ -177,7 +235,11 @@ public void onMLSPrepareCommitTransition(int transitionId, @NonNull ByteBuffer c
}
}

public void onMLSWelcome(int transitionId, @NonNull ByteBuffer welcome) {
public synchronized void onMLSWelcome(int transitionId, @NonNull ByteBuffer welcome) {
if (shutdown) {
return;
}

log.debug("Handling MLS welcome transition transitionId={}", transitionId);
boolean joinedGroup = session.processWelcome(welcome, getRecognizedUserIds());

Expand Down