diff --git a/deploy/pom.xml b/deploy/pom.xml
index 9d03b6e..f8972f4 100644
--- a/deploy/pom.xml
+++ b/deploy/pom.xml
@@ -5,7 +5,7 @@
com.pinterest.memq
memq-parent
- 1.0.1
+ 1.0.2-SNAPSHOT
../pom.xml
memq-deploy
diff --git a/memq-client-all/pom.xml b/memq-client-all/pom.xml
index c7044b9..425ac39 100644
--- a/memq-client-all/pom.xml
+++ b/memq-client-all/pom.xml
@@ -6,7 +6,7 @@
com.pinterest.memq
memq-parent
- 1.0.1
+ 1.0.2-SNAPSHOT
../pom.xml
memq-client-all
@@ -48,7 +48,7 @@
com.pinterest.memq
memq-client
- 1.0.1
+ 1.0.2-SNAPSHOT
diff --git a/memq-client/pom.xml b/memq-client/pom.xml
index a737a94..9882ca1 100644
--- a/memq-client/pom.xml
+++ b/memq-client/pom.xml
@@ -4,7 +4,7 @@
com.pinterest.memq
memq-parent
- 1.0.1
+ 1.0.2-SNAPSHOT
../pom.xml
memq-client
diff --git a/memq-commons/pom.xml b/memq-commons/pom.xml
index 24391fe..a0e21de 100644
--- a/memq-commons/pom.xml
+++ b/memq-commons/pom.xml
@@ -5,7 +5,7 @@
com.pinterest.memq
memq-parent
- 1.0.1
+ 1.0.2-SNAPSHOT
../pom.xml
memq-commons
diff --git a/memq-examples/pom.xml b/memq-examples/pom.xml
index d051e48..4baba57 100644
--- a/memq-examples/pom.xml
+++ b/memq-examples/pom.xml
@@ -5,7 +5,7 @@
com.pinterest.memq
memq-parent
- 1.0.1
+ 1.0.2-SNAPSHOT
../pom.xml
memq-examples
diff --git a/memq/pom.xml b/memq/pom.xml
index 58f7e46..99084c9 100644
--- a/memq/pom.xml
+++ b/memq/pom.xml
@@ -6,7 +6,7 @@
com.pinterest.memq
memq-parent
- 1.0.1
+ 1.0.2-SNAPSHOT
../pom.xml
memq
diff --git a/memq/src/main/java/com/pinterest/memq/core/MemqManager.java b/memq/src/main/java/com/pinterest/memq/core/MemqManager.java
index d3e7305..36ccb93 100644
--- a/memq/src/main/java/com/pinterest/memq/core/MemqManager.java
+++ b/memq/src/main/java/com/pinterest/memq/core/MemqManager.java
@@ -62,6 +62,7 @@ public class MemqManager implements Managed {
private static final Gson gson = new Gson();
private Map processorMap = new ConcurrentHashMap<>();
private Map topicMap = new ConcurrentHashMap<>();
+ private Map topicLastAccessMs = new ConcurrentHashMap<>();
private MemqConfig configuration;
private ScheduledExecutorService timerService;
private ScheduledExecutorService cleanupService;
@@ -89,9 +90,11 @@ public void init() throws Exception {
if (file.exists()) {
byte[] bytes = Files.readAllBytes(file.toPath());
TopicAssignment[] topics = gson.fromJson(new String(bytes), TopicAssignment[].class);
- topicMap = new ConcurrentHashMap<>();
- for (TopicAssignment topicConfig : topics) {
- topicMap.put(topicConfig.getTopic(), topicConfig);
+ if (topics != null) {
+ topicMap = new ConcurrentHashMap<>();
+ for (TopicAssignment topicConfig : topics) {
+ topicMap.put(topicConfig.getTopic(), topicConfig);
+ }
}
}
if (configuration.getTopicConfig() != null) {
@@ -99,8 +102,12 @@ public void init() throws Exception {
topicMap.put(topicConfig.getTopic(), new TopicAssignment(topicConfig, -1));
}
}
- for (Entry entry : topicMap.entrySet()) {
- createTopicProcessor(entry.getValue());
+ boolean assignmentsEnabled = configuration.getClusteringConfig() == null
+ || configuration.getClusteringConfig().isEnableAssignments();
+ if (assignmentsEnabled) {
+ for (Entry entry : topicMap.entrySet()) {
+ createTopicProcessor(entry.getValue());
+ }
}
}
@@ -148,6 +155,7 @@ public void createTopicProcessor(TopicAssignment topicConfig) throws BadRequestE
processorMap.put(topicConfig.getTopic(), tp);
topicMap.put(topicConfig.getTopic(), topicConfig);
+ topicLastAccessMs.put(topicConfig.getTopic(), System.currentTimeMillis());
logger.info("Configured and started TopicProcessor for:" + topicConfig.getTopic());
}
@@ -183,6 +191,7 @@ public Future> deleteTopicProcessor(String topic) {
}
processorMap.remove(topic);
topicMap.remove(topic);
+ topicLastAccessMs.remove(topic);
});
}
@@ -211,6 +220,47 @@ public Map getRegistry() {
return metricsRegistryMap;
}
/**
 * Returns the current assignment for the given topic, or {@code null} if the
 * topic is not known to this broker.
 *
 * @param topic the topic name
 * @return the topic's assignment, or {@code null} when absent
 */
public TopicAssignment getTopicAssignment(String topic) {
  return topicMap.get(topic);
}
+
+ public TopicProcessor getOrCreateTopicProcessor(String topic) throws Exception {
+ TopicProcessor topicProcessor = processorMap.get(topic);
+ if (topicProcessor != null) {
+ return topicProcessor;
+ }
+ TopicAssignment assignment = topicMap.get(topic);
+ if (assignment == null) {
+ throw new NotFoundException("Topic not found:" + topic);
+ }
+ createTopicProcessor(assignment);
+ return processorMap.get(topic);
+ }
+
/**
 * Records the current wall-clock time as the last access time for the topic.
 * Read by the idle-topic cleanup task to decide which processors to reap.
 *
 * @param topic the topic name
 */
public void touchTopic(String topic) {
  topicLastAccessMs.put(topic, System.currentTimeMillis());
}
+
+ public void startIdleTopicCleanup(long maxIdleMs) {
+ if (maxIdleMs <= 0) {
+ return;
+ }
+ cleanupService.scheduleAtFixedRate(() -> {
+ long now = System.currentTimeMillis();
+ for (Entry entry : topicLastAccessMs.entrySet()) {
+ String topic = entry.getKey();
+ Long lastAccess = entry.getValue();
+ if (lastAccess == null) {
+ continue;
+ }
+ if (now - lastAccess > maxIdleMs && processorMap.containsKey(topic)) {
+ logger.info("Deleting idle TopicProcessor for topic:" + topic);
+ deleteTopicProcessor(topic);
+ }
+ }
+ }, maxIdleMs, maxIdleMs, TimeUnit.MILLISECONDS);
+ }
+
@Override
public void start() throws Exception {
logger.info("Memq manager started");
diff --git a/memq/src/main/java/com/pinterest/memq/core/clustering/MemqGovernor.java b/memq/src/main/java/com/pinterest/memq/core/clustering/MemqGovernor.java
index 36c2e13..11428bd 100644
--- a/memq/src/main/java/com/pinterest/memq/core/clustering/MemqGovernor.java
+++ b/memq/src/main/java/com/pinterest/memq/core/clustering/MemqGovernor.java
@@ -16,6 +16,7 @@
package com.pinterest.memq.core.clustering;
import java.util.Collections;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -125,18 +126,22 @@ public void takeLeadership(CuratorFramework client) throws Exception {
}
});
- if (clusteringConfig.isEnableBalancer()) {
- Balancer balancer = new Balancer(config, this, client, leaderSelector);
- Thread thBalancer = new Thread(balancer);
- thBalancer.setName("BalancerThread");
- thBalancer.setDaemon(true);
- thBalancer.start();
- }
+ if (clusteringConfig.isEnableAssignments()) {
+ if (clusteringConfig.isEnableBalancer()) {
+ Balancer balancer = new Balancer(config, this, client, leaderSelector);
+ Thread thBalancer = new Thread(balancer);
+ thBalancer.setName("BalancerThread");
+ thBalancer.setDaemon(true);
+ thBalancer.start();
+ }
- Thread th = new Thread(new MetadataPoller(client, topicMetadataMap));
- th.setName("MetadataPollerThread");
- th.setDaemon(true);
- th.start();
+ Thread th = new Thread(new MetadataPoller(client, topicMetadataMap));
+ th.setName("MetadataPollerThread");
+ th.setDaemon(true);
+ th.start();
+ } else {
+ mgr.startIdleTopicCleanup(clusteringConfig.getMaxIdleMs());
+ }
if (clusteringConfig.isEnableLeaderSelector()) {
leaderSelector.autoRequeue();
@@ -171,7 +176,7 @@ private void initializeZNodesAndWatchers(CuratorFramework client) throws Excepti
client.create().withMode(CreateMode.EPHEMERAL).forPath(brokerZnodePath,
GSON.toJson(broker).getBytes());
- if (clusteringConfig.isEnableLocalAssigner()) {
+ if (clusteringConfig.isEnableAssignments() && clusteringConfig.isEnableLocalAssigner()) {
Thread th = new Thread(new TopicAssignmentWatcher(mgr, brokerZnodePath, broker, client));
th.setDaemon(true);
th.setName("TopicAssignmentWatcher");
@@ -183,6 +188,31 @@ public Map getTopicMetadataMap() {
return topicMetadataMap;
}
+ public TopicConfig getTopicConfig(String topic) throws Exception {
+ if (client == null) {
+ return null;
+ }
+ String path = ZNODE_TOPICS_BASE + topic;
+ if (client.checkExists().forPath(path) == null) {
+ return null;
+ }
+ String topicConfig = new String(client.getData().forPath(path));
+ return GSON.fromJson(topicConfig, TopicConfig.class);
+ }
+
+ public Set getAllBrokers() throws Exception {
+ Set brokers = new HashSet<>();
+ if (client == null) {
+ return brokers;
+ }
+ for (String id : client.getChildren().forPath(ZNODE_BROKERS)) {
+ String brokerInfo = new String(client.getData().forPath(ZNODE_BROKERS_BASE + id));
+ Broker broker = GSON.fromJson(brokerInfo, Broker.class);
+ brokers.add(broker);
+ }
+ return brokers;
+ }
+
public static List convertTopicAssignmentsSetToList(Set topicAssignments) {
return topicAssignments.stream().map(TopicConfig::getTopic).collect(Collectors.toList());
}
diff --git a/memq/src/main/java/com/pinterest/memq/core/config/ClusteringConfig.java b/memq/src/main/java/com/pinterest/memq/core/config/ClusteringConfig.java
index 228c999..b18f5c3 100644
--- a/memq/src/main/java/com/pinterest/memq/core/config/ClusteringConfig.java
+++ b/memq/src/main/java/com/pinterest/memq/core/config/ClusteringConfig.java
@@ -24,6 +24,8 @@ public class ClusteringConfig {
private boolean enableLocalAssigner = true;
private boolean addBootstrapTopics = true;
private boolean enableExpiration = true;
+ private boolean enableAssignments = true;
+ private long maxIdleMs = 5 * 60 * 1000; // 5 minutes
public boolean isAddBootstrapTopics() {
return addBootstrapTopics;
@@ -80,4 +82,20 @@ public boolean isEnableExpiration() {
public void setEnableExpiration(boolean enableExpiration) {
this.enableExpiration = enableExpiration;
}
+
/**
 * Whether this broker participates in centralized topic assignments.
 * When false, the balancer and metadata poller are not started and topic
 * processors are created lazily on demand instead.
 */
public boolean isEnableAssignments() {
  return enableAssignments;
}

public void setEnableAssignments(boolean enableAssignments) {
  this.enableAssignments = enableAssignments;
}

/**
 * Idle threshold in milliseconds, passed to MemqManager.startIdleTopicCleanup
 * to reap topic processors that have not been accessed recently.
 */
public long getMaxIdleMs() {
  return maxIdleMs;
}

public void setMaxIdleMs(long maxIdleMs) {
  this.maxIdleMs = maxIdleMs;
}
}
diff --git a/memq/src/main/java/com/pinterest/memq/core/config/NettyServerConfig.java b/memq/src/main/java/com/pinterest/memq/core/config/NettyServerConfig.java
index e462a29..c054fd4 100644
--- a/memq/src/main/java/com/pinterest/memq/core/config/NettyServerConfig.java
+++ b/memq/src/main/java/com/pinterest/memq/core/config/NettyServerConfig.java
@@ -28,6 +28,8 @@ public class NettyServerConfig {
private int brokerInputTrafficShapingMetricsReportIntervalSec = 60; // 1 minute by default
// SSL
private SSLConfig sslConfig;
+ // Fair queueing configuration
+ private QueueingConfig queueingConfig = null;
public int getBrokerInputTrafficShapingMetricsReportIntervalSec() {
return brokerInputTrafficShapingMetricsReportIntervalSec;
@@ -94,4 +96,12 @@ public void setEnableEpoll(boolean enableEpoll) {
this.enableEpoll = enableEpoll;
}
/**
 * Returns the optional fair-queueing configuration for the Netty server, or
 * {@code null} when queueing is not configured (the default).
 */
public QueueingConfig getQueueingConfig() {
  return queueingConfig;
}

/** Sets the fair-queueing configuration; {@code null} disables queueing. */
public void setQueueingConfig(QueueingConfig queueingConfig) {
  this.queueingConfig = queueingConfig;
}
+
}
diff --git a/memq/src/main/java/com/pinterest/memq/core/config/QueueingConfig.java b/memq/src/main/java/com/pinterest/memq/core/config/QueueingConfig.java
new file mode 100644
index 0000000..dd6885f
--- /dev/null
+++ b/memq/src/main/java/com/pinterest/memq/core/config/QueueingConfig.java
@@ -0,0 +1,95 @@
+/**
+ * Copyright 2022 Pinterest, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.pinterest.memq.core.config;
+
+import com.pinterest.memq.core.rpc.queue.DeficitRoundRobinStrategy;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Configuration for fair-queueing mechanism in packet processing.
+ */
+public class QueueingConfig {
+
+ private boolean enabled = false;
+
+ /**
+ * The fully qualified class name of the QueueingStrategy implementation to use.
+ * Default is DeficitRoundRobinStrategy.
+ */
+ private String strategyClass = DeficitRoundRobinStrategy.class.getName();
+
+ /**
+ * Number of threads in the dequeue thread pool.
+ * Each thread processes a subset of the topic queues.
+ */
+ private int dequeueThreadPoolSize = 4;
+
+ /**
+ * Maximum bytes of pending requests per topic queue.
+ * If a queue exceeds this limit, requests will be rejected.
+ * Default is 100MB per topic.
+ */
+ private long maxQueueBytesPerTopic = 100 * 1024 * 1024; // 100MB default
+
+ /**
+ * Strategy-specific configuration options.
+ * Keys and values depend on the strategy implementation.
+ * For example, DeficitRoundRobinStrategy uses "quantum" key.
+ */
+ private Map strategyConfig = new HashMap<>();
+
+ public boolean isEnabled() {
+ return enabled;
+ }
+
+ public void setEnabled(boolean enabled) {
+ this.enabled = enabled;
+ }
+
+ public String getStrategyClass() {
+ return strategyClass;
+ }
+
+ public void setStrategyClass(String strategyClass) {
+ this.strategyClass = strategyClass;
+ }
+
+ public int getDequeueThreadPoolSize() {
+ return dequeueThreadPoolSize;
+ }
+
+ public void setDequeueThreadPoolSize(int dequeueThreadPoolSize) {
+ this.dequeueThreadPoolSize = dequeueThreadPoolSize;
+ }
+
+ public long getMaxQueueBytesPerTopic() {
+ return maxQueueBytesPerTopic;
+ }
+
+ public void setMaxQueueBytesPerTopic(long maxQueueBytesPerTopic) {
+ this.maxQueueBytesPerTopic = maxQueueBytesPerTopic;
+ }
+
+ public Map getStrategyConfig() {
+ return strategyConfig;
+ }
+
+ public void setStrategyConfig(Map strategyConfig) {
+ this.strategyConfig = strategyConfig;
+ }
+}
diff --git a/memq/src/main/java/com/pinterest/memq/core/rpc/MemqNettyServer.java b/memq/src/main/java/com/pinterest/memq/core/rpc/MemqNettyServer.java
index c400616..747d0a1 100644
--- a/memq/src/main/java/com/pinterest/memq/core/rpc/MemqNettyServer.java
+++ b/memq/src/main/java/com/pinterest/memq/core/rpc/MemqNettyServer.java
@@ -41,6 +41,8 @@
import com.pinterest.memq.core.config.AuthorizerConfig;
import com.pinterest.memq.core.config.MemqConfig;
import com.pinterest.memq.core.config.NettyServerConfig;
+import com.pinterest.memq.core.config.QueueingConfig;
+import com.pinterest.memq.core.rpc.queue.QueuedPacketSwitchingHandler;
import com.pinterest.memq.core.security.Authorizer;
import com.pinterest.memq.core.utils.DaemonThreadFactory;
import com.pinterest.memq.core.utils.MemqUtils;
@@ -137,6 +139,11 @@ public void initialize() throws Exception {
trafficShapingHandler.startPeriodicMetricsReporting(childGroup);
}
+ // Create the packet switch handler ONCE and share across all connections.
+ // This is critical because QueuedPacketSwitchingHandler creates its own thread pool.
+ final PacketSwitchingHandler packetSwitchHandler = createPacketSwitchHandler(
+ nettyServerConfig.getQueueingConfig(), authorizer, registry);
+
if (useEpoll) {
serverBootstrap.channel(EpollServerSocketChannel.class);
} else {
@@ -170,7 +177,7 @@ protected void initChannel(SocketChannel channel) throws Exception {
pipeline.addLast(new LengthFieldBasedFrameDecoder(ByteOrder.BIG_ENDIAN,
nettyServerConfig.getMaxFrameByteLength(), 0, Integer.BYTES, 0, 0, false));
pipeline.addLast(new MemqResponseEncoder(registry));
- pipeline.addLast(new MemqRequestDecoder(memqManager, memqGovernor, authorizer, registry));
+ pipeline.addLast(new MemqRequestDecoder(packetSwitchHandler, registry));
}
});
@@ -229,6 +236,26 @@ private Authorizer enableAuthenticationAuthorizationAuditing(MemqConfig configur
return null;
}
+ /**
+ * Create the packet switch handler. If queueing is enabled, creates a
+ * QueuedPacketSwitchingHandler with its own thread pool; otherwise creates
+ * a regular PacketSwitchingHandler.
+ *
+ * This handler is created ONCE and shared across all connections to avoid
+ * creating a new thread pool per connection.
+ */
+ private PacketSwitchingHandler createPacketSwitchHandler(QueueingConfig queueingConfig,
+ Authorizer authorizer,
+ MetricRegistry registry) throws Exception {
+ if (queueingConfig != null && queueingConfig.isEnabled()) {
+ logger.info("Creating QueuedPacketSwitchingHandler with fair queueing enabled");
+ return new QueuedPacketSwitchingHandler(
+ memqManager, memqGovernor, authorizer, registry, queueingConfig);
+ } else {
+ return new PacketSwitchingHandler(memqManager, memqGovernor, authorizer, registry);
+ }
+ }
+
private EventLoopGroup getEventLoopGroup(int nThreads) {
if (useEpoll) {
logger.info("Epoll is available and will be used");
diff --git a/memq/src/main/java/com/pinterest/memq/core/rpc/MemqRequestDecoder.java b/memq/src/main/java/com/pinterest/memq/core/rpc/MemqRequestDecoder.java
index a86f7eb..a6a0781 100644
--- a/memq/src/main/java/com/pinterest/memq/core/rpc/MemqRequestDecoder.java
+++ b/memq/src/main/java/com/pinterest/memq/core/rpc/MemqRequestDecoder.java
@@ -32,9 +32,6 @@
import com.pinterest.memq.commons.protocol.RequestPacket;
import com.pinterest.memq.commons.protocol.ResponseCodes;
import com.pinterest.memq.commons.protocol.ResponsePacket;
-import com.pinterest.memq.core.MemqManager;
-import com.pinterest.memq.core.clustering.MemqGovernor;
-import com.pinterest.memq.core.security.Authorizer;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
@@ -45,15 +42,20 @@ public class MemqRequestDecoder extends ChannelInboundHandlerAdapter {
private static final Logger logger = Logger
.getLogger(MemqRequestDecoder.class.getCanonicalName());
- private Counter errorCounter;
- private PacketSwitchingHandler packetSwitchHandler;
+ private final Counter errorCounter;
+ private final PacketSwitchingHandler packetSwitchHandler;
- public MemqRequestDecoder(MemqManager mgr,
- MemqGovernor governor,
- Authorizer authorizer,
+ /**
+ * Create a MemqRequestDecoder with a shared PacketSwitchingHandler.
+ * This constructor should be used to avoid creating a new handler per connection.
+ *
+ * @param packetSwitchHandler the shared handler instance
+ * @param registry the metrics registry
+ */
+ public MemqRequestDecoder(PacketSwitchingHandler packetSwitchHandler,
MetricRegistry registry) {
- errorCounter = registry.counter("request.error");
- packetSwitchHandler = new PacketSwitchingHandler(mgr, governor, authorizer, registry);
+ this.errorCounter = registry.counter("request.error");
+ this.packetSwitchHandler = packetSwitchHandler;
}
@Override
diff --git a/memq/src/main/java/com/pinterest/memq/core/rpc/PacketSwitchingHandler.java b/memq/src/main/java/com/pinterest/memq/core/rpc/PacketSwitchingHandler.java
index 4940cec..08e14d6 100644
--- a/memq/src/main/java/com/pinterest/memq/core/rpc/PacketSwitchingHandler.java
+++ b/memq/src/main/java/com/pinterest/memq/core/rpc/PacketSwitchingHandler.java
@@ -25,21 +25,25 @@
import com.codahale.metrics.Counter;
import com.codahale.metrics.MetricRegistry;
+import com.pinterest.memq.commons.protocol.Broker;
import com.pinterest.memq.commons.protocol.ReadRequestPacket;
import com.pinterest.memq.commons.protocol.RequestPacket;
import com.pinterest.memq.commons.protocol.ResponseCodes;
import com.pinterest.memq.commons.protocol.ResponsePacket;
+import com.pinterest.memq.commons.protocol.TopicAssignment;
import com.pinterest.memq.commons.protocol.TopicMetadata;
import com.pinterest.memq.commons.protocol.TopicMetadataRequestPacket;
import com.pinterest.memq.commons.protocol.TopicMetadataResponsePacket;
+import com.pinterest.memq.commons.protocol.TopicConfig;
import com.pinterest.memq.commons.protocol.WriteRequestPacket;
-import com.pinterest.memq.commons.protocol.WriteResponsePacket;
import com.pinterest.memq.commons.protocol.Broker.BrokerType;
import com.pinterest.memq.core.MemqManager;
import com.pinterest.memq.core.clustering.MemqGovernor;
import com.pinterest.memq.core.processing.TopicProcessor;
import com.pinterest.memq.core.security.Authorizer;
+import java.util.Properties;
+
import io.netty.channel.ChannelHandlerContext;
public class PacketSwitchingHandler {
@@ -88,15 +92,19 @@ public void handle(ChannelHandlerContext ctx,
case TOPIC_METADATA:
TopicMetadataRequestPacket mdRequest = (TopicMetadataRequestPacket) requestPacket
.getPayload();
- TopicMetadata md = governor.getTopicMetadataMap().get(mdRequest.getTopic());
+ TopicMetadata md = null;
+ if (!assignmentsEnabled()) {
+ md = buildMetadataForAllBrokers(mdRequest.getTopic());
+ } else {
+ md = governor.getTopicMetadataMap().get(mdRequest.getTopic());
+ }
if (md == null) {
throw TOPIC_NOT_FOUND;
- } else {
- ResponsePacket msg = new ResponsePacket(requestPacket.getProtocolVersion(),
- requestPacket.getClientRequestId(), requestPacket.getRequestType(), ResponseCodes.OK,
- new TopicMetadataResponsePacket(md));
- ctx.writeAndFlush(msg);
}
+ ResponsePacket msg = new ResponsePacket(requestPacket.getProtocolVersion(),
+ requestPacket.getClientRequestId(), requestPacket.getRequestType(), ResponseCodes.OK,
+ new TopicMetadataResponsePacket(md));
+ ctx.writeAndFlush(msg);
break;
case READ:
ReadRequestPacket readPacket = (ReadRequestPacket) requestPacket.getPayload();
@@ -144,6 +152,34 @@ protected void executeWriteRequest(ChannelHandlerContext ctx,
throw SERVER_NOT_INITIALIZED;
}
TopicProcessor topicProcessor = mgr.getProcessorMap().get(writePacket.getTopicName());
+ if (!assignmentsEnabled()) {
+ if (topicProcessor == null) {
+ try {
+ TopicAssignment assignment = mgr.getTopicAssignment(writePacket.getTopicName());
+ if (assignment == null) {
+ TopicConfig topicConfig = governor.getTopicConfig(writePacket.getTopicName());
+ if (topicConfig != null) {
+ assignment = new TopicAssignment(topicConfig, -1);
+ mgr.createTopicProcessor(assignment);
+ }
+ }
+ if (assignment == null) {
+ throw new NotFoundException("Topic not found:" + writePacket.getTopicName());
+ }
+ topicProcessor = mgr.getOrCreateTopicProcessor(writePacket.getTopicName());
+ } catch (NotFoundException e) {
+ logger.severe("Topic not found:" + writePacket.getTopicName());
+ throw TOPIC_NOT_FOUND;
+ } catch (Exception e) {
+ throw new InternalServerErrorException(e);
+ }
+ }
+ writeRquestCounter.inc();
+ mgr.touchTopic(writePacket.getTopicName());
+ topicProcessor.registerChannel(ctx.channel());
+ topicProcessor.write(requestPacket, writePacket, ctx);
+ return;
+ }
if (topicProcessor != null) {
writeRquestCounter.inc();
topicProcessor.registerChannel(ctx.channel());
@@ -168,4 +204,42 @@ protected void executeReadRequest(ChannelHandlerContext ctx,
topicProcessor.read(requestPacket, readPacket, ctx);
}
}
+
+ private boolean assignmentsEnabled() {
+ return mgr.getConfiguration().getClusteringConfig() == null
+ || mgr.getConfiguration().getClusteringConfig().isEnableAssignments();
+ }
+
+ private TopicMetadata buildMetadataForAllBrokers(String topic) throws Exception {
+ TopicMetadata md;
+ TopicAssignment assignment = mgr.getTopicAssignment(topic);
+ if (assignment != null) {
+ md = new TopicMetadata(topic,
+ assignment.getStorageHandlerName(),
+ assignment.getStorageHandlerConfig());
+ } else {
+ TopicConfig topicConfig = governor.getTopicConfig(topic);
+ if (topicConfig != null) {
+ md = new TopicMetadata(topicConfig);
+ } else {
+ // When assignments are disabled, allow metadata for topics not yet known locally.
+ // Storage handler info will be empty; clients can still discover brokers.
+ md = new TopicMetadata(topic, null, new Properties());
+ }
+ }
+ for (Broker broker : governor.getAllBrokers()) {
+ switch (broker.getBrokerType()) {
+ case READ:
+ md.getReadBrokers().add(broker);
+ break;
+ case WRITE:
+ md.getWriteBrokers().add(broker);
+ break;
+ default:
+ md.getReadBrokers().add(broker);
+ md.getWriteBrokers().add(broker);
+ }
+ }
+ return md;
+ }
}
diff --git a/memq/src/main/java/com/pinterest/memq/core/rpc/queue/DeficitRoundRobinStrategy.java b/memq/src/main/java/com/pinterest/memq/core/rpc/queue/DeficitRoundRobinStrategy.java
new file mode 100644
index 0000000..94a9c74
--- /dev/null
+++ b/memq/src/main/java/com/pinterest/memq/core/rpc/queue/DeficitRoundRobinStrategy.java
@@ -0,0 +1,315 @@
+/**
+ * Copyright 2022 Pinterest, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.pinterest.memq.core.rpc.queue;
+
+import com.pinterest.memq.core.config.QueueingConfig;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.logging.Logger;
+
+/**
+ * Deficit Round Robin (DRR) implementation of QueueingStrategy.
+ *
+ * DRR is a fair scheduling algorithm that provides weighted fair queuing.
+ * Each queue maintains a deficit counter. When a queue is visited:
+ * - Its deficit counter is incremented by a quantum value
+ * - Requests are dequeued as long as their size is less than or equal to the deficit counter
+ * - The request's size is subtracted from the deficit counter
+ *
+ * This ensures fair bandwidth allocation across all topics while being
+ * work-conserving (idle queues don't waste capacity).
+ *
+ * Strategy-specific configuration (in strategyConfig map):
+ * - "quantum": int - bytes per round per queue (default: 64KB)
+ */
+public class DeficitRoundRobinStrategy extends QueueingStrategy {
+
+ private static final Logger logger = Logger.getLogger(DeficitRoundRobinStrategy.class.getName());
+
+ /**
+ * Configuration key for the quantum value in strategyConfig.
+ */
+ public static final String CONFIG_KEY_QUANTUM = "quantum";
+
+ /**
+ * Default quantum size in bytes (64KB).
+ * This determines how many bytes worth of requests can be processed
+ * per round for each queue.
+ */
+ public static final int DEFAULT_QUANTUM = 65536;
+
+ /**
+ * Per-topic queue holding pending requests.
+ */
+ private final ConcurrentHashMap> queues = new ConcurrentHashMap<>();
+
+ /**
+ * Current byte count for each topic queue.
+ */
+ private final ConcurrentHashMap queueBytes = new ConcurrentHashMap<>();
+
+ /**
+ * Deficit counter for each topic (used by DRR algorithm).
+ */
+ private final ConcurrentHashMap deficitCounters = new ConcurrentHashMap<>();
+
+ /**
+ * Round-robin state for each dequeue thread - tracks current position in topic list.
+ */
+ private final ConcurrentHashMap> threadIterators = new ConcurrentHashMap<>();
+
+ /**
+ * Cached list of topics for round-robin iteration.
+ */
+ private final ConcurrentHashMap> threadTopicLists = new ConcurrentHashMap<>();
+
+ private int quantum;
+ private long maxQueueBytes;
+
+ @Override
+ public void init(QueueingConfig config) {
+ super.init(config);
+ this.quantum = getQuantumFromConfig(config);
+ this.maxQueueBytes = config.getMaxQueueBytesPerTopic();
+ }
+
+ /**
+ * Extract the quantum value from strategy config, or use default.
+ */
+ private int getQuantumFromConfig(QueueingConfig config) {
+ Object quantumObj = config.getStrategyConfig().get(CONFIG_KEY_QUANTUM);
+ if (quantumObj != null) {
+ if (quantumObj instanceof Number) {
+ return ((Number) quantumObj).intValue();
+ }
+ if (quantumObj instanceof String) {
+ try {
+ return Integer.parseInt((String) quantumObj);
+ } catch (NumberFormatException e) {
+ logger.warning("Invalid quantum value in config: " + quantumObj + ", using default");
+ }
+ }
+ }
+ return DEFAULT_QUANTUM;
+ }
+
+ /**
+ * Get the current quantum value being used.
+ *
+ * @return the quantum in bytes
+ */
+ public int getQuantum() {
+ return quantum;
+ }
+
+ @Override
+ public boolean enqueue(QueuedRequest request) {
+ String topicName = request.getTopicName();
+ int requestSize = Math.max(1, request.getSize());
+
+ // Get or create queue and byte counter
+ ConcurrentLinkedQueue queue = queues.computeIfAbsent(topicName,
+ k -> new ConcurrentLinkedQueue<>());
+ AtomicLong bytes = queueBytes.computeIfAbsent(topicName, k -> new AtomicLong(0));
+ deficitCounters.computeIfAbsent(topicName, k -> new AtomicInteger(0));
+
+ // Check if adding this request would exceed the byte limit
+ // Use CAS loop for thread safety
+ while (true) {
+ long currentBytes = bytes.get();
+ if (currentBytes + requestSize > maxQueueBytes) {
+ logger.warning("Queue full for topic: " + topicName +
+ ", current bytes: " + currentBytes + ", request size: " + requestSize +
+ ", max: " + maxQueueBytes + ", rejecting request");
+ return false;
+ }
+ if (bytes.compareAndSet(currentBytes, currentBytes + requestSize)) {
+ queue.offer(request);
+ return true;
+ }
+ // CAS failed, retry
+ }
+ }
+
+ @Override
+ public QueuedRequest dequeue(Set assignedTopics) {
+ if (assignedTopics.isEmpty()) {
+ return null;
+ }
+
+ long threadId = Thread.currentThread().getId();
+
+ // Get or create the topic list for this thread
+ List topicList = threadTopicLists.computeIfAbsent(threadId, k -> new ArrayList<>());
+
+ // Refresh the topic list if needed (topics may have been added/removed)
+ refreshTopicListIfNeeded(topicList, assignedTopics);
+
+ if (topicList.isEmpty()) {
+ return null;
+ }
+
+ // Get or create iterator for this thread
+ Iterator iterator = threadIterators.get(threadId);
+ if (iterator == null || !iterator.hasNext()) {
+ iterator = topicList.iterator();
+ threadIterators.put(threadId, iterator);
+ }
+
+ // Try each topic in round-robin fashion
+ int topicsChecked = 0;
+ int totalTopics = topicList.size();
+
+ while (topicsChecked < totalTopics) {
+ if (!iterator.hasNext()) {
+ iterator = topicList.iterator();
+ threadIterators.put(threadId, iterator);
+ }
+
+ String topicName = iterator.next();
+ topicsChecked++;
+
+ ConcurrentLinkedQueue queue = queues.get(topicName);
+ if (queue == null || queue.isEmpty()) {
+ // Reset deficit when queue is empty (work-conserving)
+ AtomicInteger deficit = deficitCounters.get(topicName);
+ if (deficit != null) {
+ deficit.set(0);
+ }
+ continue;
+ }
+
+ AtomicInteger deficit = deficitCounters.get(topicName);
+ if (deficit == null) {
+ continue;
+ }
+
+ // Peek at the next request
+ QueuedRequest request = queue.peek();
+ if (request == null) {
+ continue;
+ }
+
+ int requestSize = Math.max(1, request.getSize()); // Minimum size of 1 to prevent starvation
+ int currentDeficit = deficit.get();
+
+ // Add quantum to deficit
+ if (currentDeficit < requestSize) {
+ currentDeficit = deficit.addAndGet(quantum);
+ }
+
+ // Check if we can service this request
+ if (currentDeficit >= requestSize) {
+ // Try to poll the request
+ request = queue.poll();
+ if (request != null) {
+ // Subtract request size from deficit and update byte counter
+ deficit.addAndGet(-requestSize);
+ AtomicLong bytes = queueBytes.get(topicName);
+ if (bytes != null) {
+ bytes.addAndGet(-requestSize);
+ }
+ return request;
+ }
+ }
+ }
+
+ return null;
+ }
+
+ /**
+ * Refresh the topic list if the assigned topics have changed.
+ */
+ private void refreshTopicListIfNeeded(List topicList, Set assignedTopics) {
+ // Check if we need to refresh
+ Set currentTopics = new HashSet<>(topicList);
+ Set activeAssigned = new HashSet<>();
+
+ for (String topic : assignedTopics) {
+ if (queues.containsKey(topic)) {
+ activeAssigned.add(topic);
+ }
+ }
+
+ if (!currentTopics.equals(activeAssigned)) {
+ topicList.clear();
+ topicList.addAll(activeAssigned);
+ // Reset iterator since list changed
+ threadIterators.remove(Thread.currentThread().getId());
+ }
+ }
+
+ @Override
+ public long getPendingBytes() {
+ long total = 0;
+ for (AtomicLong bytes : queueBytes.values()) {
+ total += bytes.get();
+ }
+ return total;
+ }
+
+ @Override
+ public long getPendingBytes(String topicName) {
+ AtomicLong bytes = queueBytes.get(topicName);
+ return bytes != null ? bytes.get() : 0;
+ }
+
+ /**
+ * Get the current number of pending requests across all queues.
+ *
+ * @return total number of pending requests
+ */
+ public int getPendingCount() {
+ int total = 0;
+ for (ConcurrentLinkedQueue queue : queues.values()) {
+ total += queue.size();
+ }
+ return total;
+ }
+
+ /**
+ * Get the current number of pending requests for a specific topic.
+ *
+ * @param topicName the topic name
+ * @return number of pending requests for the topic
+ */
+ public int getPendingCount(String topicName) {
+ ConcurrentLinkedQueue queue = queues.get(topicName);
+ return queue != null ? queue.size() : 0;
+ }
+
+ @Override
+ public Set getActiveTopics() {
+ return new HashSet<>(queues.keySet());
+ }
+
+ @Override
+ public void shutdown() {
+ queues.clear();
+ queueBytes.clear();
+ deficitCounters.clear();
+ threadIterators.clear();
+ threadTopicLists.clear();
+ }
+}
diff --git a/memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueuedPacketSwitchingHandler.java b/memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueuedPacketSwitchingHandler.java
new file mode 100644
index 0000000..0e5aaee
--- /dev/null
+++ b/memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueuedPacketSwitchingHandler.java
@@ -0,0 +1,298 @@
+/**
+ * Copyright 2022 Pinterest, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.pinterest.memq.core.rpc.queue;
+
+import java.security.Principal;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+import javax.ws.rs.ServiceUnavailableException;
+
+import com.codahale.metrics.Gauge;
+import com.codahale.metrics.MetricRegistry;
+import com.pinterest.memq.commons.protocol.RequestPacket;
+import com.pinterest.memq.commons.protocol.RequestType;
+import com.pinterest.memq.commons.protocol.WriteRequestPacket;
+import com.pinterest.memq.core.MemqManager;
+import com.pinterest.memq.core.clustering.MemqGovernor;
+import com.pinterest.memq.core.config.QueueingConfig;
+import com.pinterest.memq.core.rpc.PacketSwitchingHandler;
+import com.pinterest.memq.core.security.Authorizer;
+
+import io.netty.channel.ChannelHandlerContext;
+
+/**
+ * A PacketSwitchingHandler that implements fair queueing for write requests.
+ *
+ * Write requests are enqueued into topic-specific queues and processed by
+ * a configurable thread pool using a fair scheduling algorithm (e.g., DRR).
+ *
+ * Read and metadata requests are processed directly without queueing.
+ */
+public class QueuedPacketSwitchingHandler extends PacketSwitchingHandler {
+
+  private static final Logger logger = Logger.getLogger(QueuedPacketSwitchingHandler.class.getName());
+
+  // Pluggable fair-scheduling implementation, loaded from configuration.
+  private final QueueingStrategy strategy;
+  private final ExecutorService dequeueExecutor;
+  // Flipped to false by shutdown(); each dequeue loop polls it per iteration.
+  private final AtomicBoolean running = new AtomicBoolean(true);
+  // Marks topics whose per-topic pending-bytes gauge is already registered.
+  private final ConcurrentHashMap<String, Boolean> perTopicGaugeRegistered = new ConcurrentHashMap<>();
+
+  /**
+   * Number of dequeue threads.
+   * Topics are distributed across threads using consistent hashing on topic name.
+   */
+  private final int numDequeueThreads;
+
+  public QueuedPacketSwitchingHandler(MemqManager mgr,
+                                      MemqGovernor governor,
+                                      Authorizer authorizer,
+                                      MetricRegistry registry,
+                                      QueueingConfig queueingConfig) throws Exception {
+    super(mgr, governor, authorizer, registry);
+    this.numDequeueThreads = queueingConfig.getDequeueThreadPoolSize();
+
+    // Initialize the queueing strategy before any worker thread can touch it
+    this.strategy = createStrategy(queueingConfig);
+
+    // Named daemon threads so a lingering pool never blocks JVM exit
+    this.dequeueExecutor = Executors.newFixedThreadPool(
+        numDequeueThreads,
+        new ThreadFactory() {
+          private final AtomicInteger threadNumber = new AtomicInteger(1);
+          @Override
+          public Thread newThread(Runnable r) {
+            Thread t = new Thread(r, "memq-dequeue-" + threadNumber.getAndIncrement());
+            t.setDaemon(true);
+            return t;
+          }
+        });
+
+    // Start the dequeue worker threads
+    for (int i = 0; i < numDequeueThreads; i++) {
+      final int threadIndex = i;
+      dequeueExecutor.submit(() -> dequeueLoop(threadIndex));
+    }
+
+    logger.info("QueuedPacketSwitchingHandler initialized with strategy: " +
+        queueingConfig.getStrategyClass() + ", threads: " + numDequeueThreads);
+  }
+
+  /**
+   * Create a QueueingStrategy instance from configuration.
+   *
+   * @throws Exception if the configured class cannot be loaded, cast, or
+   *         instantiated via its no-arg constructor
+   */
+  @SuppressWarnings("unchecked")
+  private QueueingStrategy createStrategy(QueueingConfig config) throws Exception {
+    Class<? extends QueueingStrategy> strategyClass =
+        (Class<? extends QueueingStrategy>) Class.forName(config.getStrategyClass());
+    QueueingStrategy strategy = strategyClass.getDeclaredConstructor().newInstance();
+    strategy.init(config);
+    return strategy;
+  }
+
+  /**
+   * Route an incoming request: WRITE requests are enqueued for fair
+   * scheduling; all other request types fall through to the parent handler.
+   *
+   * @throws ServiceUnavailableException if the topic queue rejects the write
+   */
+  @Override
+  public void handle(ChannelHandlerContext ctx,
+                     RequestPacket requestPacket,
+                     Principal principal,
+                     String clientAddress) throws Exception {
+    // Only queue WRITE requests
+    if (requestPacket.getRequestType() == RequestType.WRITE) {
+      WriteRequestPacket writePacket = (WriteRequestPacket) requestPacket.getPayload();
+      registerPerTopicMetricsIfNeeded(writePacket.getTopicName());
+
+      // Retain the ByteBuf since it will be released by MemqRequestDecoder's finally block
+      // but we need it to stay alive until the request is dequeued and processed.
+      // The buffer will be released in processQueuedRequest after processing.
+      writePacket.getData().retain();
+
+      QueuedRequest queuedRequest = new QueuedRequest(ctx, requestPacket, writePacket);
+
+      if (strategy.enqueue(queuedRequest)) {
+        MetricRegistry topicRegistry = mgr.getRegistry().get(queuedRequest.getTopicName());
+        if (topicRegistry != null) {
+          topicRegistry.counter("queue.enqueued.bytes").inc(queuedRequest.getSize());
+        }
+      } else {
+        // Release the buffer since we retained it but failed to enqueue
+        writePacket.getData().release();
+        MetricRegistry topicRegistry = mgr.getRegistry().get(queuedRequest.getTopicName());
+        if (topicRegistry != null) {
+          topicRegistry.counter("queue.rejected.bytes").inc(queuedRequest.getSize());
+        }
+        // TODO: we should return a response packet with an error code signalling congestion
+        throw new ServiceUnavailableException("Queue full for topic: " + writePacket.getTopicName());
+      }
+    } else {
+      // Process non-write requests directly (metadata, read requests)
+      super.handle(ctx, requestPacket, principal, clientAddress);
+    }
+  }
+
+  /**
+   * Main dequeue loop for each worker thread.
+   * Each thread is responsible for processing a subset of topics.
+   */
+  private void dequeueLoop(int threadIndex) {
+    logger.info("Dequeue worker thread " + threadIndex + " started");
+
+    while (running.get()) {
+      try {
+        // Topics this thread is responsible for, recomputed each iteration
+        // so newly created topic queues are picked up.
+        Set<String> assignedTopics = getAssignedTopics(threadIndex);
+
+        QueuedRequest request = strategy.dequeue(assignedTopics);
+
+        if (request != null) {
+          registerPerTopicMetricsIfNeeded(request.getTopicName());
+          long queueLatencyNanos = System.nanoTime() - request.getEnqueueTimeNanos();
+          MetricRegistry topicRegistry = mgr.getRegistry().get(request.getTopicName());
+          if (topicRegistry != null) {
+            topicRegistry.timer("queue.latency.nanos").update(queueLatencyNanos, TimeUnit.NANOSECONDS);
+            topicRegistry.counter("queue.dequeued.bytes").inc(request.getSize());
+          }
+          processQueuedRequest(request);
+        } else {
+          // No work available, sleep briefly to avoid spinning
+          // TODO: event-driven approach to avoid spinning
+          Thread.sleep(1);
+        }
+      } catch (InterruptedException e) {
+        Thread.currentThread().interrupt();
+        break;
+      } catch (Exception e) {
+        logger.log(Level.SEVERE, "Error in dequeue loop", e);
+      }
+    }
+
+    logger.info("Dequeue worker thread " + threadIndex + " stopped");
+  }
+
+  /**
+   * Get the set of topics assigned to a specific thread index.
+   * Uses hashing on the topic name to distribute topics across threads.
+   */
+  private Set<String> getAssignedTopics(int threadIndex) {
+    Set<String> assignedTopics = new HashSet<>();
+
+    for (String topic : strategy.getActiveTopics()) {
+      // floorMod guards against Integer.MIN_VALUE, where Math.abs(hashCode())
+      // stays negative and abs(h) % n would yield a negative thread index.
+      int assignedThread = Math.floorMod(topic.hashCode(), numDequeueThreads);
+      if (assignedThread == threadIndex) {
+        assignedTopics.add(topic);
+      }
+    }
+
+    return assignedTopics;
+  }
+
+  /**
+   * Process a dequeued request by calling the parent's executeWriteRequest.
+   */
+  private void processQueuedRequest(QueuedRequest request) {
+    try {
+      executeWriteRequest(request.getCtx(), request.getRequestPacket(), request.getWritePacket());
+    } catch (Exception e) {
+      logger.log(Level.SEVERE, "Error processing queued request for topic: " +
+          request.getTopicName(), e);
+      // The channel context may need error handling
+      handleProcessingError(request, e);
+    } finally {
+      // Release the ByteBuf that was retained when the request was enqueued.
+      // This balances the retain() call in handle().
+      try {
+        request.getWritePacket().getData().release();
+      } catch (Exception e) {
+        logger.log(Level.WARNING, "Failed to release buffer for request", e);
+      }
+    }
+  }
+
+  /**
+   * Handle errors that occur during request processing by propagating the
+   * exception through the channel pipeline, if the channel is still active.
+   */
+  private void handleProcessingError(QueuedRequest request, Exception e) {
+    try {
+      ChannelHandlerContext ctx = request.getCtx();
+      if (ctx.channel().isActive()) {
+        // Error will be handled by the exception handler in the pipeline
+        ctx.fireExceptionCaught(e);
+      }
+    } catch (Exception ex) {
+      logger.log(Level.WARNING, "Failed to handle processing error", ex);
+    }
+  }
+
+  /**
+   * Register the per-topic pending-bytes gauge exactly once per topic.
+   * No-op if the topic has no registry yet.
+   */
+  private void registerPerTopicMetricsIfNeeded(String topicName) {
+    MetricRegistry topicRegistry = mgr.getRegistry().get(topicName);
+    if (topicRegistry == null) {
+      return;
+    }
+    // putIfAbsent returning non-null means another caller already registered it
+    if (perTopicGaugeRegistered.putIfAbsent(topicName, Boolean.TRUE) != null) {
+      return;
+    }
+    topicRegistry.gauge("queue.pending.bytes",
+        () -> (Gauge<Long>) () -> strategy.getPendingBytes(topicName));
+  }
+
+  /**
+   * Shutdown the handler and its thread pool, waiting up to 30s for the
+   * workers to drain before forcing termination.
+   */
+  public void shutdown() {
+    running.set(false);
+    strategy.shutdown();
+
+    dequeueExecutor.shutdown();
+    try {
+      if (!dequeueExecutor.awaitTermination(30, TimeUnit.SECONDS)) {
+        dequeueExecutor.shutdownNow();
+      }
+    } catch (InterruptedException e) {
+      dequeueExecutor.shutdownNow();
+      Thread.currentThread().interrupt();
+    }
+
+    logger.info("QueuedPacketSwitchingHandler shutdown complete");
+  }
+
+  /**
+   * Get the queueing strategy being used.
+   */
+  public QueueingStrategy getStrategy() {
+    return strategy;
+  }
+
+  /**
+   * Get the pending bytes across all queues.
+   */
+  public long getPendingBytes() {
+    return strategy.getPendingBytes();
+  }
+}
diff --git a/memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueuedRequest.java b/memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueuedRequest.java
new file mode 100644
index 0000000..c24daa4
--- /dev/null
+++ b/memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueuedRequest.java
@@ -0,0 +1,73 @@
+/**
+ * Copyright 2022 Pinterest, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.pinterest.memq.core.rpc.queue;
+
+import com.pinterest.memq.commons.protocol.RequestPacket;
+import com.pinterest.memq.commons.protocol.WriteRequestPacket;
+
+import io.netty.channel.ChannelHandlerContext;
+
+/**
+ * Wrapper class that holds all the context needed to process a queued request.
+ *
+ * Immutable after construction: the topic name and payload size are captured
+ * eagerly from the write packet, and the enqueue timestamp (System.nanoTime)
+ * is recorded at construction time for queue-latency metrics.
+ */
+public class QueuedRequest {
+
+ private final ChannelHandlerContext ctx;
+ private final RequestPacket requestPacket;
+ private final WriteRequestPacket writePacket;
+ // Cached from writePacket so repeated metric lookups avoid packet calls.
+ private final String topicName;
+ // Payload size in bytes, cached from writePacket.getDataLength().
+ private final int size;
+ // Monotonic timestamp (System.nanoTime) taken when this wrapper was built.
+ private final long enqueueTimeNanos;
+
+ public QueuedRequest(ChannelHandlerContext ctx,
+ RequestPacket requestPacket,
+ WriteRequestPacket writePacket) {
+ this.ctx = ctx;
+ this.requestPacket = requestPacket;
+ this.writePacket = writePacket;
+ this.topicName = writePacket.getTopicName();
+ this.size = writePacket.getDataLength();
+ this.enqueueTimeNanos = System.nanoTime();
+ }
+
+ public ChannelHandlerContext getCtx() {
+ return ctx;
+ }
+
+ public RequestPacket getRequestPacket() {
+ return requestPacket;
+ }
+
+ public WriteRequestPacket getWritePacket() {
+ return writePacket;
+ }
+
+ public String getTopicName() {
+ return topicName;
+ }
+
+ /**
+ * Returns the size of the request in bytes.
+ * Used by strategies like DRR that need to track byte counts.
+ */
+ public int getSize() {
+ return size;
+ }
+
+ /**
+ * Returns the System.nanoTime() captured at construction; only meaningful
+ * when compared against another nanoTime reading in the same JVM.
+ */
+ public long getEnqueueTimeNanos() {
+ return enqueueTimeNanos;
+ }
+}
diff --git a/memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueueingStrategy.java b/memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueueingStrategy.java
new file mode 100644
index 0000000..5f39861
--- /dev/null
+++ b/memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueueingStrategy.java
@@ -0,0 +1,91 @@
+/**
+ * Copyright 2022 Pinterest, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.pinterest.memq.core.rpc.queue;
+
+import com.pinterest.memq.core.config.QueueingConfig;
+
+import java.util.Set;
+
+/**
+ * Abstract base class for fair queueing strategies.
+ *
+ * Implementations of this class provide different algorithms for
+ * scheduling requests across multiple topic queues in a fair manner.
+ */
+public abstract class QueueingStrategy {
+
+  // Configuration captured by init(); available to subclasses.
+  protected QueueingConfig config;
+
+  /**
+   * Initialize the strategy with configuration.
+   * Called once after instantiation.
+   *
+   * @param config the queueing configuration
+   */
+  public void init(QueueingConfig config) {
+    this.config = config;
+  }
+
+  /**
+   * Enqueue a request into the appropriate topic queue.
+   *
+   * @param request the queued request to enqueue
+   * @return true if the request was successfully enqueued, false if rejected
+   *         (e.g., due to queue being full)
+   */
+  public abstract boolean enqueue(QueuedRequest request);
+
+  /**
+   * Dequeue the next request to be processed according to the strategy's
+   * fair scheduling algorithm.
+   *
+   * This method should only process queues that are assigned to the calling thread.
+   *
+   * @param assignedTopics the set of topic names that this thread is responsible for
+   * @return the next request to process, or null if no requests are available
+   */
+  public abstract QueuedRequest dequeue(Set<String> assignedTopics);
+
+  /**
+   * Get the current pending bytes across all queues.
+   *
+   * @return total pending bytes
+   */
+  public abstract long getPendingBytes();
+
+  /**
+   * Get the current pending bytes for a specific topic.
+   *
+   * @param topicName the topic name
+   * @return pending bytes for the topic
+   */
+  public abstract long getPendingBytes(String topicName);
+
+  /**
+   * Get all topic names that currently have queues (even if empty).
+   *
+   * @return set of topic names
+   */
+  public abstract Set<String> getActiveTopics();
+
+  /**
+   * Called when the strategy is being shut down.
+   * Implementations should clean up any resources.
+   */
+  public void shutdown() {
+    // Default no-op implementation
+  }
+}
diff --git a/memq/src/test/java/com/pinterest/memq/core/TestMemqManager.java b/memq/src/test/java/com/pinterest/memq/core/TestMemqManager.java
index 5ef21a4..5932ba4 100644
--- a/memq/src/test/java/com/pinterest/memq/core/TestMemqManager.java
+++ b/memq/src/test/java/com/pinterest/memq/core/TestMemqManager.java
@@ -16,15 +16,18 @@
package com.pinterest.memq.core;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
import java.io.File;
import java.nio.file.Paths;
import java.util.HashMap;
+import java.util.concurrent.TimeUnit;
import org.junit.Test;
import com.pinterest.memq.commons.protocol.TopicAssignment;
import com.pinterest.memq.commons.protocol.TopicConfig;
+import com.pinterest.memq.core.config.ClusteringConfig;
import com.pinterest.memq.core.config.MemqConfig;
public class TestMemqManager {
@@ -67,4 +70,41 @@ public void testTopicConfig() throws Exception {
}
+ /**
+  * With assignments disabled, topic processors must be created lazily on
+  * first use and torn down again by the idle-cleanup task once the topic
+  * has been idle longer than maxIdleMs.
+  */
+ @Test
+ public void testAssignmentsDisabledLazyCreationAndIdleCleanup() throws Exception {
+ MemqConfig configuration = new MemqConfig();
+ ClusteringConfig clusteringConfig = new ClusteringConfig();
+ clusteringConfig.setEnableAssignments(false);
+ clusteringConfig.setMaxIdleMs(50);
+ configuration.setClusteringConfig(clusteringConfig);
+ File tmpFile = File.createTempFile("testmgrcache", "", Paths.get("/tmp").toFile());
+ tmpFile.deleteOnExit();
+ configuration.setTopicCacheFile(tmpFile.toString());
+ // Delete the placeholder so the manager starts with no cached topics.
+ tmpFile.delete();
+ configuration.setTopicConfig(new TopicConfig[] {
+ new TopicConfig("test", "delayeddevnull")
+ });
+
+ MemqManager mgr = new MemqManager(null, configuration, new HashMap<>());
+ mgr.init();
+
+ // No processors should be created on init when assignments are disabled.
+ assertEquals(0, mgr.getProcessorMap().size());
+ assertEquals(1, mgr.getTopicAssignment().size());
+
+ // Create on first use.
+ assertTrue(mgr.getOrCreateTopicProcessor("test") != null);
+ assertEquals(1, mgr.getProcessorMap().size());
+
+ // Touch the topic and start idle cleanup with a short timeout.
+ mgr.touchTopic("test");
+ mgr.startIdleTopicCleanup(25);
+
+ // Poll (up to 2s) until the idle cleanup removes the processor.
+ long deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(2);
+ while (System.currentTimeMillis() < deadline && mgr.getProcessorMap().size() > 0) {
+ Thread.sleep(20);
+ }
+ assertEquals(0, mgr.getProcessorMap().size());
+ }
+
}
diff --git a/memq/src/test/java/com/pinterest/memq/core/clustering/TestMemqGovernor.java b/memq/src/test/java/com/pinterest/memq/core/clustering/TestMemqGovernor.java
index 8b80d6e..b83b023 100644
--- a/memq/src/test/java/com/pinterest/memq/core/clustering/TestMemqGovernor.java
+++ b/memq/src/test/java/com/pinterest/memq/core/clustering/TestMemqGovernor.java
@@ -19,13 +19,20 @@
import java.io.FileNotFoundException;
import java.io.FileReader;
+import java.util.HashMap;
+import java.util.Set;
import org.junit.Test;
import com.google.gson.Gson;
import com.google.gson.JsonIOException;
import com.google.gson.JsonSyntaxException;
+import com.pinterest.memq.commons.protocol.Broker;
import com.pinterest.memq.commons.protocol.TopicConfig;
+import com.pinterest.memq.core.MemqManager;
+import com.pinterest.memq.core.config.ClusteringConfig;
+import com.pinterest.memq.core.config.LocalEnvironmentProvider;
+import com.pinterest.memq.core.config.MemqConfig;
public class TestMemqGovernor {
@@ -41,4 +48,17 @@ public void testBackwardsCompatibility() throws JsonSyntaxException, JsonIOExcep
assertEquals("customs3aync2", newConf.getStorageHandlerName());
}
+  /**
+   * With assignments disabled and no ZooKeeper client, the governor should
+   * report an empty broker set rather than throwing.
+   */
+  @Test
+  public void testGetAllBrokersWithNoZkClient() throws Exception {
+    MemqConfig config = new MemqConfig();
+    ClusteringConfig clusteringConfig = new ClusteringConfig();
+    clusteringConfig.setEnableAssignments(false);
+    config.setClusteringConfig(clusteringConfig);
+    MemqManager mgr = new MemqManager(null, config, new HashMap<>());
+    MemqGovernor governor = new MemqGovernor(mgr, config, new LocalEnvironmentProvider());
+
+    Set<Broker> brokers = governor.getAllBrokers();
+    assertEquals(0, brokers.size());
+  }
+
}
diff --git a/memq/src/test/java/com/pinterest/memq/core/integration/TestMemqClientServerIntegration.java b/memq/src/test/java/com/pinterest/memq/core/integration/TestMemqClientServerIntegration.java
index e776eda..f4618b3 100644
--- a/memq/src/test/java/com/pinterest/memq/core/integration/TestMemqClientServerIntegration.java
+++ b/memq/src/test/java/com/pinterest/memq/core/integration/TestMemqClientServerIntegration.java
@@ -95,6 +95,7 @@
import com.pinterest.memq.core.rpc.MemqNettyServer;
import com.pinterest.memq.core.rpc.MemqRequestDecoder;
import com.pinterest.memq.core.rpc.MemqResponseEncoder;
+import com.pinterest.memq.core.rpc.PacketSwitchingHandler;
import com.pinterest.memq.core.rpc.TestAuditor;
import com.pinterest.memq.core.utils.MiscUtils;
import com.salesforce.kafka.test.junit4.SharedKafkaTestResource;
@@ -141,7 +142,7 @@ public void testProducerAndServer() throws IOException {
MetricRegistry registry = new MetricRegistry();
EmbeddedChannel ech = new EmbeddedChannel(new MemqResponseEncoder(registry),
new LengthFieldBasedFrameDecoder(ByteOrder.BIG_ENDIAN, 2 * 1024 * 1024, 0, 4, 0, 0, false),
- new MemqRequestDecoder(null, null, null, registry));
+ new MemqRequestDecoder(new PacketSwitchingHandler(null, null, null, registry), registry));
ech.writeInbound(output);
ech.checkException();
assertEquals(1, ech.outboundMessages().size());
diff --git a/memq/src/test/java/com/pinterest/memq/core/rpc/TestPacketSwitchingHandler.java b/memq/src/test/java/com/pinterest/memq/core/rpc/TestPacketSwitchingHandler.java
new file mode 100644
index 0000000..a0809ea
--- /dev/null
+++ b/memq/src/test/java/com/pinterest/memq/core/rpc/TestPacketSwitchingHandler.java
@@ -0,0 +1,263 @@
+/**
+ * Copyright 2022 Pinterest, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.pinterest.memq.core.rpc;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.fail;
+
+import java.io.File;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Set;
+
+import org.junit.Test;
+
+import com.codahale.metrics.MetricRegistry;
+import com.pinterest.memq.commons.protocol.Broker;
+import com.pinterest.memq.commons.protocol.Broker.BrokerType;
+import com.pinterest.memq.commons.protocol.RequestPacket;
+import com.pinterest.memq.commons.protocol.RequestType;
+import com.pinterest.memq.commons.protocol.ResponsePacket;
+import com.pinterest.memq.commons.protocol.TopicMetadata;
+import com.pinterest.memq.commons.protocol.TopicMetadataRequestPacket;
+import com.pinterest.memq.commons.protocol.TopicMetadataResponsePacket;
+import com.pinterest.memq.commons.protocol.TopicConfig;
+import com.pinterest.memq.commons.protocol.WriteRequestPacket;
+import com.pinterest.memq.core.MemqManager;
+import com.pinterest.memq.core.clustering.MemqGovernor;
+import com.pinterest.memq.core.config.ClusteringConfig;
+import com.pinterest.memq.core.config.LocalEnvironmentProvider;
+import com.pinterest.memq.core.config.MemqConfig;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.embedded.EmbeddedChannel;
+
+public class TestPacketSwitchingHandler {
+
+ private static class TestGovernor extends MemqGovernor {
+ private final Set brokers;
+
+ public TestGovernor(MemqManager mgr,
+ MemqConfig config,
+ LocalEnvironmentProvider provider,
+ Set brokers) {
+ super(mgr, config, provider);
+ this.brokers = brokers;
+ }
+
+ @Override
+ public Set getAllBrokers() {
+ return brokers;
+ }
+ }
+
+ private static class TestGovernorWithTopicConfig extends TestGovernor {
+ private final TopicConfig topicConfig;
+
+ public TestGovernorWithTopicConfig(MemqManager mgr,
+ MemqConfig config,
+ LocalEnvironmentProvider provider,
+ Set brokers,
+ TopicConfig topicConfig) {
+ super(mgr, config, provider, brokers);
+ this.topicConfig = topicConfig;
+ }
+
+ @Override
+ public TopicConfig getTopicConfig(String topic) {
+ if (topicConfig != null && topicConfig.getTopic().equals(topic)) {
+ return topicConfig;
+ }
+ return null;
+ }
+ }
+
+ private static class PacketSwitchingAdapter extends ChannelInboundHandlerAdapter {
+ private final PacketSwitchingHandler handler;
+
+ public PacketSwitchingAdapter(PacketSwitchingHandler handler) {
+ this.handler = handler;
+ }
+
+ @Override
+ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+ handler.handle(ctx, (RequestPacket) msg, null, null);
+ }
+ }
+
+ @Test
+ public void testMetadataAllBrokersWhenAssignmentsDisabled() throws Exception {
+ MemqConfig config = new MemqConfig();
+ ClusteringConfig clusteringConfig = new ClusteringConfig();
+ clusteringConfig.setEnableAssignments(false);
+ config.setClusteringConfig(clusteringConfig);
+ File tmpFile = File.createTempFile("testmgrcache", "", Paths.get("/tmp").toFile());
+ tmpFile.deleteOnExit();
+ config.setTopicCacheFile(tmpFile.toString());
+ config.setTopicConfig(new TopicConfig[] {
+ new TopicConfig("test", "delayeddevnull")
+ });
+
+ MemqManager mgr = new MemqManager(null, config, new HashMap<>());
+ mgr.init();
+
+ Set brokers = new HashSet<>(Arrays.asList(
+ new Broker("1.1.1.1", (short) 9092, "test", "local", BrokerType.READ_WRITE, new HashSet<>()),
+ new Broker("1.1.1.2", (short) 9092, "test", "local", BrokerType.READ_WRITE, new HashSet<>())
+ ));
+
+ MemqGovernor governor = new TestGovernor(mgr, config, new LocalEnvironmentProvider(), brokers);
+ PacketSwitchingHandler handler = new PacketSwitchingHandler(mgr, governor, null, new MetricRegistry());
+
+ EmbeddedChannel channel = new EmbeddedChannel(new PacketSwitchingAdapter(handler));
+ RequestPacket request = new RequestPacket((short) 0, 1L, RequestType.TOPIC_METADATA,
+ new TopicMetadataRequestPacket("test"));
+ channel.writeInbound(request);
+
+ ResponsePacket response = channel.readOutbound();
+ assertNotNull(response);
+ assertEquals(RequestType.TOPIC_METADATA, response.getRequestType());
+
+ TopicMetadataResponsePacket payload = (TopicMetadataResponsePacket) response.getPacket();
+ TopicMetadata metadata = payload.getMetadata();
+ assertNotNull(metadata);
+ assertEquals(2, metadata.getReadBrokers().size());
+ assertEquals(2, metadata.getWriteBrokers().size());
+ }
+
+ @Test
+ public void testUnknownTopicMetadataAssignmentsEnabled() throws Exception {
+ MemqConfig config = new MemqConfig();
+ ClusteringConfig clusteringConfig = new ClusteringConfig();
+ clusteringConfig.setEnableAssignments(true);
+ config.setClusteringConfig(clusteringConfig);
+ config.setTopicConfig(new TopicConfig[] {
+ new TopicConfig("known", "delayeddevnull")
+ });
+
+ MemqManager mgr = new MemqManager(null, config, new HashMap<>());
+ mgr.init();
+
+ MemqGovernor governor = new TestGovernor(mgr, config, new LocalEnvironmentProvider(),
+ Collections.emptySet());
+ PacketSwitchingHandler handler = new PacketSwitchingHandler(mgr, governor, null, new MetricRegistry());
+
+ EmbeddedChannel channel = new EmbeddedChannel(new PacketSwitchingAdapter(handler));
+ RequestPacket request = new RequestPacket((short) 0, 1L, RequestType.TOPIC_METADATA,
+ new TopicMetadataRequestPacket("unknown"));
+ try {
+ channel.writeInbound(request);
+ fail("Expected TOPIC_NOT_FOUND for unknown topic when assignments are enabled");
+ } catch (Exception e) {
+ assertEquals(PacketSwitchingHandler.TOPIC_NOT_FOUND, e);
+ }
+ }
+
+ @Test
+ public void testUnknownTopicMetadataAssignmentsDisabled() throws Exception {
+ MemqConfig config = new MemqConfig();
+ ClusteringConfig clusteringConfig = new ClusteringConfig();
+ clusteringConfig.setEnableAssignments(false);
+ config.setClusteringConfig(clusteringConfig);
+
+ MemqManager mgr = new MemqManager(null, config, new HashMap<>());
+ mgr.init();
+
+ Set brokers = new HashSet<>(Arrays.asList(
+ new Broker("1.1.1.1", (short) 9092, "test", "local", BrokerType.READ_WRITE, new HashSet<>())
+ ));
+
+ MemqGovernor governor = new TestGovernor(mgr, config, new LocalEnvironmentProvider(), brokers);
+ PacketSwitchingHandler handler = new PacketSwitchingHandler(mgr, governor, null, new MetricRegistry());
+
+ EmbeddedChannel channel = new EmbeddedChannel(new PacketSwitchingAdapter(handler));
+ RequestPacket request = new RequestPacket((short) 0, 1L, RequestType.TOPIC_METADATA,
+ new TopicMetadataRequestPacket("unknown"));
+ channel.writeInbound(request);
+
+ ResponsePacket response = channel.readOutbound();
+ assertNotNull(response);
+ TopicMetadataResponsePacket payload = (TopicMetadataResponsePacket) response.getPacket();
+ TopicMetadata metadata = payload.getMetadata();
+ assertEquals("unknown", metadata.getTopicName());
+ assertEquals(1, metadata.getWriteBrokers().size());
+ assertEquals(1, metadata.getReadBrokers().size());
+ }
+
+ @Test
+ public void testUnknownTopicWriteAssignmentsDisabled() throws Exception {
+ MemqConfig config = new MemqConfig();
+ ClusteringConfig clusteringConfig = new ClusteringConfig();
+ clusteringConfig.setEnableAssignments(false);
+ config.setClusteringConfig(clusteringConfig);
+
+ MemqManager mgr = new MemqManager(null, config, new HashMap<>());
+ mgr.init();
+
+ MemqGovernor governor = new TestGovernor(mgr, config, new LocalEnvironmentProvider(),
+ Collections.emptySet());
+ PacketSwitchingHandler handler = new PacketSwitchingHandler(mgr, governor, null, new MetricRegistry());
+
+ WriteRequestPacket writePacket = new WriteRequestPacket();
+ writePacket.setTopicName("unknown");
+ writePacket.setData(new byte[1]);
+ RequestPacket request = new RequestPacket((short) 0, 1L, RequestType.WRITE, writePacket);
+ try {
+ handler.handle(null, request, null, null);
+ fail("Expected TOPIC_NOT_FOUND for unknown topic write when assignments are disabled");
+ } catch (Exception e) {
+ assertEquals(PacketSwitchingHandler.TOPIC_NOT_FOUND, e);
+ }
+ }
+
+ @Test
+ public void testMetadataUsesZkTopicConfigWhenAssignmentsDisabled() throws Exception {
+ MemqConfig config = new MemqConfig();
+ ClusteringConfig clusteringConfig = new ClusteringConfig();
+ clusteringConfig.setEnableAssignments(false);
+ config.setClusteringConfig(clusteringConfig);
+
+ MemqManager mgr = new MemqManager(null, config, new HashMap<>());
+ mgr.init();
+
+ TopicConfig topicConfig = new TopicConfig("loggen_loadtest", "delayeddevnull");
+ Set brokers = new HashSet<>(Arrays.asList(
+ new Broker("1.1.1.1", (short) 9092, "test", "local", BrokerType.READ_WRITE, new HashSet<>())
+ ));
+
+ MemqGovernor governor = new TestGovernorWithTopicConfig(mgr, config,
+ new LocalEnvironmentProvider(), brokers, topicConfig);
+ PacketSwitchingHandler handler = new PacketSwitchingHandler(mgr, governor, null, new MetricRegistry());
+
+ EmbeddedChannel channel = new EmbeddedChannel(new PacketSwitchingAdapter(handler));
+ RequestPacket request = new RequestPacket((short) 0, 1L, RequestType.TOPIC_METADATA,
+ new TopicMetadataRequestPacket("loggen_loadtest"));
+ channel.writeInbound(request);
+
+ ResponsePacket response = channel.readOutbound();
+ assertNotNull(response);
+ TopicMetadataResponsePacket payload = (TopicMetadataResponsePacket) response.getPacket();
+ TopicMetadata metadata = payload.getMetadata();
+ assertEquals("loggen_loadtest", metadata.getTopicName());
+ assertEquals("delayeddevnull", metadata.getStorageHandlerName());
+ assertEquals(1, metadata.getWriteBrokers().size());
+ }
+}
diff --git a/memq/src/test/java/com/pinterest/memq/core/rpc/queue/TestQueuedPacketSwitchingHandler.java b/memq/src/test/java/com/pinterest/memq/core/rpc/queue/TestQueuedPacketSwitchingHandler.java
new file mode 100644
index 0000000..bce473a
--- /dev/null
+++ b/memq/src/test/java/com/pinterest/memq/core/rpc/queue/TestQueuedPacketSwitchingHandler.java
@@ -0,0 +1,907 @@
+/**
+ * Copyright 2022 Pinterest, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.pinterest.memq.core.rpc.queue;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import com.pinterest.memq.commons.protocol.RequestPacket;
+import com.pinterest.memq.commons.protocol.WriteRequestPacket;
+import com.pinterest.memq.core.config.QueueingConfig;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.PooledByteBufAllocator;
+
+/**
+ * Tests for QueuedPacketSwitchingHandler and related queueing components.
+ *
+ * These tests focus on the core queueing functionality:
+ * - Enqueueing and dequeueing
+ * - Fairness across topic queues
+ * - Queue full scenarios
+ */
+public class TestQueuedPacketSwitchingHandler {
+
+ private QueueingConfig config;
+ private DeficitRoundRobinStrategy strategy;
+
+ @Before
+ public void setUp() {
+ config = new QueueingConfig();
+ config.setEnabled(true);
+ config.setDequeueThreadPoolSize(4);
+ config.setMaxQueueBytesPerTopic(1024 * 1024); // 1MB per topic
+ config.getStrategyConfig().put(DeficitRoundRobinStrategy.CONFIG_KEY_QUANTUM, 1024); // 1KB quantum
+
+ strategy = new DeficitRoundRobinStrategy();
+ strategy.init(config);
+ }
+
+ @After
+ public void tearDown() {
+ if (strategy != null) {
+ strategy.shutdown();
+ }
+ }
+
+ /**
+ * Helper to create a QueuedRequest with a given topic name and data size.
+ */
+ private QueuedRequest createRequest(String topicName, int dataSize) {
+ WriteRequestPacket writePacket = new WriteRequestPacket();
+ writePacket.setTopicName(topicName);
+ writePacket.setData(new byte[dataSize]);
+
+ RequestPacket requestPacket = new RequestPacket();
+ // ChannelHandlerContext is null for these unit tests - we only test queueing logic
+ return new QueuedRequest(null, requestPacket, writePacket);
+ }
+
+ // ===========================================
+ // Basic Enqueue/Dequeue Tests
+ // ===========================================
+
+ @Test
+ public void testBasicEnqueueDequeue() {
+ String topic = "test-topic";
+ QueuedRequest request = createRequest(topic, 100);
+
+ // Enqueue should succeed
+ assertTrue(strategy.enqueue(request));
+ assertEquals(1, strategy.getPendingCount());
+ assertEquals(1, strategy.getPendingCount(topic));
+
+ // Dequeue should return the request
+ Set<String> topics = new HashSet<>();
+ topics.add(topic);
+ QueuedRequest dequeued = strategy.dequeue(topics);
+
+ assertNotNull(dequeued);
+ assertEquals(topic, dequeued.getTopicName());
+ assertEquals(0, strategy.getPendingCount());
+ }
+
+ @Test
+ public void testEnqueueMultipleTopics() {
+ String topic1 = "topic-1";
+ String topic2 = "topic-2";
+ String topic3 = "topic-3";
+
+ // Enqueue requests for different topics
+ assertTrue(strategy.enqueue(createRequest(topic1, 100)));
+ assertTrue(strategy.enqueue(createRequest(topic2, 100)));
+ assertTrue(strategy.enqueue(createRequest(topic3, 100)));
+
+ assertEquals(3, strategy.getPendingCount());
+ assertEquals(1, strategy.getPendingCount(topic1));
+ assertEquals(1, strategy.getPendingCount(topic2));
+ assertEquals(1, strategy.getPendingCount(topic3));
+
+ // Active topics should contain all three
+ Set<String> activeTopics = strategy.getActiveTopics();
+ assertEquals(3, activeTopics.size());
+ assertTrue(activeTopics.contains(topic1));
+ assertTrue(activeTopics.contains(topic2));
+ assertTrue(activeTopics.contains(topic3));
+ }
+
+ @Test
+ public void testDequeueEmptyQueue() {
+ Set<String> topics = new HashSet<>();
+ topics.add("nonexistent-topic");
+
+ QueuedRequest dequeued = strategy.dequeue(topics);
+ assertNull(dequeued);
+ }
+
+ @Test
+ public void testDequeueWithNoAssignedTopics() {
+ String topic = "test-topic";
+ assertTrue(strategy.enqueue(createRequest(topic, 100)));
+
+ // Dequeue with empty assigned topics should return null
+ Set<String> emptyTopics = new HashSet<>();
+ QueuedRequest dequeued = strategy.dequeue(emptyTopics);
+ assertNull(dequeued);
+
+ // Request should still be in queue
+ assertEquals(1, strategy.getPendingCount());
+ }
+
+ @Test
+ public void testDequeueOnlyAssignedTopics() {
+ String topic1 = "topic-1";
+ String topic2 = "topic-2";
+
+ assertTrue(strategy.enqueue(createRequest(topic1, 100)));
+ assertTrue(strategy.enqueue(createRequest(topic2, 100)));
+
+ // Only assign topic1
+ Set<String> assignedTopics = new HashSet<>();
+ assignedTopics.add(topic1);
+
+ QueuedRequest dequeued = strategy.dequeue(assignedTopics);
+ assertNotNull(dequeued);
+ assertEquals(topic1, dequeued.getTopicName());
+
+ // topic2 should still be pending
+ assertEquals(1, strategy.getPendingCount());
+ assertEquals(1, strategy.getPendingCount(topic2));
+ }
+
+ // ===========================================
+ // Queue Full Scenario Tests
+ // ===========================================
+
+ @Test
+ public void testQueueFull() {
+ // Set a very small queue size in bytes (300 bytes = 3 requests of 100 bytes)
+ QueueingConfig smallConfig = new QueueingConfig();
+ smallConfig.setMaxQueueBytesPerTopic(300);
+ smallConfig.getStrategyConfig().put(DeficitRoundRobinStrategy.CONFIG_KEY_QUANTUM, 1024);
+
+ DeficitRoundRobinStrategy smallStrategy = new DeficitRoundRobinStrategy();
+ smallStrategy.init(smallConfig);
+
+ try {
+ String topic = "test-topic";
+
+ // Fill the queue (3 x 100 bytes = 300 bytes)
+ assertTrue(smallStrategy.enqueue(createRequest(topic, 100)));
+ assertTrue(smallStrategy.enqueue(createRequest(topic, 100)));
+ assertTrue(smallStrategy.enqueue(createRequest(topic, 100)));
+
+ // Fourth request should be rejected (would exceed 300 bytes)
+ assertFalse(smallStrategy.enqueue(createRequest(topic, 100)));
+
+ assertEquals(3, smallStrategy.getPendingCount());
+ assertEquals(300, smallStrategy.getPendingBytes(topic));
+ } finally {
+ smallStrategy.shutdown();
+ }
+ }
+
+ @Test
+ public void testQueueFullPerTopic() {
+ // Set a very small queue size in bytes (200 bytes per topic = 2 requests of 100 bytes)
+ QueueingConfig smallConfig = new QueueingConfig();
+ smallConfig.setMaxQueueBytesPerTopic(200);
+ smallConfig.getStrategyConfig().put(DeficitRoundRobinStrategy.CONFIG_KEY_QUANTUM, 1024);
+
+ DeficitRoundRobinStrategy smallStrategy = new DeficitRoundRobinStrategy();
+ smallStrategy.init(smallConfig);
+
+ try {
+ String topic1 = "topic-1";
+ String topic2 = "topic-2";
+
+ // Fill topic1 queue (2 x 100 bytes = 200 bytes)
+ assertTrue(smallStrategy.enqueue(createRequest(topic1, 100)));
+ assertTrue(smallStrategy.enqueue(createRequest(topic1, 100)));
+ assertFalse(smallStrategy.enqueue(createRequest(topic1, 100))); // rejected
+
+ // topic2 should still accept (independent byte limit)
+ assertTrue(smallStrategy.enqueue(createRequest(topic2, 100)));
+ assertTrue(smallStrategy.enqueue(createRequest(topic2, 100)));
+ assertFalse(smallStrategy.enqueue(createRequest(topic2, 100))); // rejected
+
+ assertEquals(4, smallStrategy.getPendingCount());
+ assertEquals(2, smallStrategy.getPendingCount(topic1));
+ assertEquals(2, smallStrategy.getPendingCount(topic2));
+ assertEquals(200, smallStrategy.getPendingBytes(topic1));
+ assertEquals(200, smallStrategy.getPendingBytes(topic2));
+ } finally {
+ smallStrategy.shutdown();
+ }
+ }
+
+ @Test
+ public void testQueueFullBytesBased() {
+ // Test that queue limit is enforced by bytes, not count
+ QueueingConfig byteConfig = new QueueingConfig();
+ byteConfig.setMaxQueueBytesPerTopic(500);
+ byteConfig.getStrategyConfig().put(DeficitRoundRobinStrategy.CONFIG_KEY_QUANTUM, 1024);
+
+ DeficitRoundRobinStrategy byteStrategy = new DeficitRoundRobinStrategy();
+ byteStrategy.init(byteConfig);
+
+ try {
+ String topic = "test-topic";
+
+ // Enqueue a large request (400 bytes)
+ assertTrue(byteStrategy.enqueue(createRequest(topic, 400)));
+ assertEquals(400, byteStrategy.getPendingBytes(topic));
+
+ // Enqueue a small request (50 bytes) - should succeed
+ assertTrue(byteStrategy.enqueue(createRequest(topic, 50)));
+ assertEquals(450, byteStrategy.getPendingBytes(topic));
+
+ // Another 50 bytes should succeed (total 500)
+ assertTrue(byteStrategy.enqueue(createRequest(topic, 50)));
+ assertEquals(500, byteStrategy.getPendingBytes(topic));
+
+ // Any additional bytes should be rejected
+ assertFalse(byteStrategy.enqueue(createRequest(topic, 1)));
+ assertFalse(byteStrategy.enqueue(createRequest(topic, 100)));
+
+ assertEquals(3, byteStrategy.getPendingCount());
+ assertEquals(500, byteStrategy.getPendingBytes(topic));
+ } finally {
+ byteStrategy.shutdown();
+ }
+ }
+
+ // ===========================================
+ // Fairness Tests
+ // ===========================================
+
+ @Test
+ public void testRoundRobinFairness() {
+ String topic1 = "topic-a";
+ String topic2 = "topic-b";
+ String topic3 = "topic-c";
+
+ // Enqueue multiple requests per topic (same size)
+ for (int i = 0; i < 10; i++) {
+ strategy.enqueue(createRequest(topic1, 100));
+ strategy.enqueue(createRequest(topic2, 100));
+ strategy.enqueue(createRequest(topic3, 100));
+ }
+
+ assertEquals(30, strategy.getPendingCount());
+
+ // Dequeue all and count per topic
+ Set<String> allTopics = new HashSet<>();
+ allTopics.add(topic1);
+ allTopics.add(topic2);
+ allTopics.add(topic3);
+
+ Map<String, AtomicInteger> dequeueCounts = new HashMap<>();
+ dequeueCounts.put(topic1, new AtomicInteger(0));
+ dequeueCounts.put(topic2, new AtomicInteger(0));
+ dequeueCounts.put(topic3, new AtomicInteger(0));
+
+ // Dequeue in batches and check fairness
+ List<String> dequeueOrder = new ArrayList<>();
+ for (int i = 0; i < 30; i++) {
+ QueuedRequest req = strategy.dequeue(allTopics);
+ assertNotNull("Should have request at iteration " + i, req);
+ dequeueCounts.get(req.getTopicName()).incrementAndGet();
+ dequeueOrder.add(req.getTopicName());
+ }
+
+ // All topics should have been dequeued completely
+ assertEquals(10, dequeueCounts.get(topic1).get());
+ assertEquals(10, dequeueCounts.get(topic2).get());
+ assertEquals(10, dequeueCounts.get(topic3).get());
+
+ // Verify round-robin behavior - check that we alternate between topics
+ // in the first few dequeues (before any topic runs out of quantum)
+ // With DRR and same-sized requests, we should see interleaving
+ verifyInterleaving(dequeueOrder.subList(0, Math.min(15, dequeueOrder.size())));
+ }
+
+ /**
+ * Verify that the dequeue order shows interleaving between topics.
+ */
+ private void verifyInterleaving(List<String> order) {
+ // Count consecutive same-topic dequeues
+ int maxConsecutive = 0;
+ int currentConsecutive = 1;
+
+ for (int i = 1; i < order.size(); i++) {
+ if (order.get(i).equals(order.get(i - 1))) {
+ currentConsecutive++;
+ maxConsecutive = Math.max(maxConsecutive, currentConsecutive);
+ } else {
+ currentConsecutive = 1;
+ }
+ }
+
+ // With fair queueing and equal-sized requests, we shouldn't see
+ // long runs of the same topic (indicates starvation)
+ assertTrue("Max consecutive from same topic should be reasonable: " + maxConsecutive,
+ maxConsecutive <= 5);
+ }
+
+ @Test
+ public void testDRRFairnessWithDifferentSizes() {
+ String smallTopic = "small";
+ String largeTopic = "large";
+
+ // Set quantum to 500 bytes
+ QueueingConfig drrConfig = new QueueingConfig();
+ drrConfig.setMaxQueueBytesPerTopic(1024 * 1024); // 1MB
+ drrConfig.getStrategyConfig().put(DeficitRoundRobinStrategy.CONFIG_KEY_QUANTUM, 500);
+
+ DeficitRoundRobinStrategy drrStrategy = new DeficitRoundRobinStrategy();
+ drrStrategy.init(drrConfig);
+
+ try {
+ // Enqueue small requests (100 bytes each) and large requests (400 bytes each)
+ for (int i = 0; i < 20; i++) {
+ drrStrategy.enqueue(createRequest(smallTopic, 100));
+ drrStrategy.enqueue(createRequest(largeTopic, 400));
+ }
+
+ assertEquals(40, drrStrategy.getPendingCount());
+
+ Set<String> allTopics = new HashSet<>();
+ allTopics.add(smallTopic);
+ allTopics.add(largeTopic);
+
+ Map<String, Long> bytesDequeued = new HashMap<>();
+ bytesDequeued.put(smallTopic, 0L);
+ bytesDequeued.put(largeTopic, 0L);
+
+ // Dequeue all requests
+ int dequeueCount = 0;
+ while (drrStrategy.getPendingCount() > 0 && dequeueCount < 50) {
+ QueuedRequest req = drrStrategy.dequeue(allTopics);
+ if (req != null) {
+ bytesDequeued.compute(req.getTopicName(),
+ (k, v) -> v + req.getSize());
+ dequeueCount++;
+ }
+ }
+
+ // Both topics should have been processed with roughly fair byte allocation
+ // Small topic: 20 * 100 = 2000 bytes
+ // Large topic: 20 * 400 = 8000 bytes
+ assertEquals(2000L, (long) bytesDequeued.get(smallTopic));
+ assertEquals(8000L, (long) bytesDequeued.get(largeTopic));
+ assertEquals(0, drrStrategy.getPendingCount());
+ } finally {
+ drrStrategy.shutdown();
+ }
+ }
+
+ @Test
+ public void testFairnessUnderLoad() {
+ // Test fairness when topics have different numbers of pending requests
+ String hotTopic = "hot-topic";
+ String coldTopic = "cold-topic";
+
+ // Hot topic has many requests
+ for (int i = 0; i < 50; i++) {
+ strategy.enqueue(createRequest(hotTopic, 100));
+ }
+
+ // Cold topic has few requests
+ for (int i = 0; i < 5; i++) {
+ strategy.enqueue(createRequest(coldTopic, 100));
+ }
+
+ Set<String> allTopics = new HashSet<>();
+ allTopics.add(hotTopic);
+ allTopics.add(coldTopic);
+
+ // Dequeue first 10 requests and check that cold topic gets fair share
+ int coldCount = 0;
+
+ for (int i = 0; i < 10; i++) {
+ QueuedRequest req = strategy.dequeue(allTopics);
+ if (req != null && coldTopic.equals(req.getTopicName())) {
+ coldCount++;
+ }
+ }
+
+ // Cold topic should get processed, not starved
+ assertTrue("Cold topic should get some processing: " + coldCount, coldCount >= 1);
+
+ // Continue until cold topic is empty
+ while (strategy.getPendingCount(coldTopic) > 0) {
+ QueuedRequest req = strategy.dequeue(allTopics);
+ if (req != null && coldTopic.equals(req.getTopicName())) {
+ coldCount++;
+ }
+ }
+
+ assertEquals(5, coldCount);
+ }
+
+ // ===========================================
+ // Thread Assignment Tests
+ // ===========================================
+
+ @Test
+ public void testTopicThreadAssignment() {
+ // Test that topics are consistently assigned to threads
+ List<String> topics = new ArrayList<>();
+ for (int i = 0; i < 20; i++) {
+ topics.add("topic-" + i);
+ }
+
+ int numThreads = 4;
+ Map<Integer, Set<String>> threadAssignments = new HashMap<>();
+
+ for (int i = 0; i < numThreads; i++) {
+ threadAssignments.put(i, new HashSet<>());
+ }
+
+ // Simulate thread assignment logic from QueuedPacketSwitchingHandler
+ for (String topic : topics) {
+ int assignedThread = Math.abs(topic.hashCode()) % numThreads;
+ threadAssignments.get(assignedThread).add(topic);
+ }
+
+ // Verify all topics are assigned
+ int totalAssigned = threadAssignments.values().stream()
+ .mapToInt(Set::size).sum();
+ assertEquals(20, totalAssigned);
+
+ // Verify each topic is assigned to exactly one thread
+ Set<String> allAssigned = new HashSet<>();
+ for (Set<String> assigned : threadAssignments.values()) {
+ for (String topic : assigned) {
+ assertFalse("Topic should only be assigned once: " + topic,
+ allAssigned.contains(topic));
+ allAssigned.add(topic);
+ }
+ }
+
+ // Verify distribution is somewhat balanced (not all in one thread)
+ for (int i = 0; i < numThreads; i++) {
+ int assigned = threadAssignments.get(i).size();
+ assertTrue("Thread " + i + " should have some topics, but has " + assigned,
+ assigned >= 1);
+ }
+ }
+
+ @Test
+ public void testConsistentThreadAssignment() {
+ // Verify that the same topic is always assigned to the same thread
+ String topic = "consistent-topic";
+ int numThreads = 4;
+
+ int expectedThread = Math.abs(topic.hashCode()) % numThreads;
+
+ // Check multiple times
+ for (int i = 0; i < 100; i++) {
+ int assignedThread = Math.abs(topic.hashCode()) % numThreads;
+ assertEquals(expectedThread, assignedThread);
+ }
+ }
+
+ // ===========================================
+ // Concurrent Access Tests
+ // ===========================================
+
+ @Test
+ public void testConcurrentEnqueue() throws InterruptedException {
+ int numThreads = 10;
+ int requestsPerThread = 100;
+ CountDownLatch startLatch = new CountDownLatch(1);
+ CountDownLatch doneLatch = new CountDownLatch(numThreads);
+ AtomicInteger successCount = new AtomicInteger(0);
+
+ // Create threads that enqueue concurrently
+ for (int t = 0; t < numThreads; t++) {
+ final int threadId = t;
+ new Thread(() -> {
+ try {
+ startLatch.await();
+ for (int i = 0; i < requestsPerThread; i++) {
+ String topic = "topic-" + (threadId % 3); // 3 topics
+ if (strategy.enqueue(createRequest(topic, 100))) {
+ successCount.incrementAndGet();
+ }
+ }
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ } finally {
+ doneLatch.countDown();
+ }
+ }).start();
+ }
+
+ // Start all threads
+ startLatch.countDown();
+
+ // Wait for completion
+ assertTrue(doneLatch.await(10, TimeUnit.SECONDS));
+
+ // All enqueues should succeed (queue is large enough)
+ assertEquals(numThreads * requestsPerThread, successCount.get());
+ assertEquals(numThreads * requestsPerThread, strategy.getPendingCount());
+ }
+
+ @Test
+ public void testConcurrentEnqueueDequeue() throws InterruptedException {
+ int numProducers = 5;
+ int numConsumers = 3;
+ int requestsPerProducer = 100;
+
+ CountDownLatch startLatch = new CountDownLatch(1);
+ CountDownLatch producersDone = new CountDownLatch(numProducers);
+ AtomicInteger produced = new AtomicInteger(0);
+ AtomicInteger consumed = new AtomicInteger(0);
+ AtomicInteger running = new AtomicInteger(1);
+
+ Set<String> allTopics = new HashSet<>();
+ for (int i = 0; i < 3; i++) {
+ allTopics.add("topic-" + i);
+ }
+
+ // Start consumers
+ for (int c = 0; c < numConsumers; c++) {
+ new Thread(() -> {
+ try {
+ startLatch.await();
+ while (running.get() == 1 || strategy.getPendingCount() > 0) {
+ QueuedRequest req = strategy.dequeue(allTopics);
+ if (req != null) {
+ consumed.incrementAndGet();
+ } else {
+ Thread.sleep(1);
+ }
+ }
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ }
+ }).start();
+ }
+
+ // Start producers
+ for (int p = 0; p < numProducers; p++) {
+ final int producerId = p;
+ new Thread(() -> {
+ try {
+ startLatch.await();
+ for (int i = 0; i < requestsPerProducer; i++) {
+ String topic = "topic-" + (producerId % 3);
+ if (strategy.enqueue(createRequest(topic, 100))) {
+ produced.incrementAndGet();
+ }
+ }
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ } finally {
+ producersDone.countDown();
+ }
+ }).start();
+ }
+
+ // Start all threads
+ startLatch.countDown();
+
+ // Wait for producers to finish
+ assertTrue(producersDone.await(10, TimeUnit.SECONDS));
+
+ // Give consumers time to drain
+ Thread.sleep(500);
+ running.set(0);
+
+ // Wait a bit more for final consumption
+ Thread.sleep(200);
+
+ // All produced should be consumed
+ assertEquals(produced.get(), consumed.get());
+ assertEquals(0, strategy.getPendingCount());
+ }
+
+ // ===========================================
+ // QueueingConfig Tests
+ // ===========================================
+
+ @Test
+ public void testQueueingConfigDefaults() {
+ QueueingConfig defaultConfig = new QueueingConfig();
+
+ assertFalse(defaultConfig.isEnabled());
+ assertEquals("com.pinterest.memq.core.rpc.queue.DeficitRoundRobinStrategy",
+ defaultConfig.getStrategyClass());
+ assertEquals(4, defaultConfig.getDequeueThreadPoolSize());
+ assertEquals(100 * 1024 * 1024, defaultConfig.getMaxQueueBytesPerTopic()); // 100MB
+ assertNotNull(defaultConfig.getStrategyConfig());
+ assertTrue(defaultConfig.getStrategyConfig().isEmpty());
+ }
+
+ @Test
+ public void testQueueingConfigSetters() {
+ QueueingConfig customConfig = new QueueingConfig();
+ customConfig.setEnabled(true);
+ customConfig.setStrategyClass("com.example.CustomStrategy");
+ customConfig.setDequeueThreadPoolSize(8);
+ customConfig.setMaxQueueBytesPerTopic(50 * 1024 * 1024); // 50MB
+ customConfig.getStrategyConfig().put(DeficitRoundRobinStrategy.CONFIG_KEY_QUANTUM, 32768);
+
+ assertTrue(customConfig.isEnabled());
+ assertEquals("com.example.CustomStrategy", customConfig.getStrategyClass());
+ assertEquals(8, customConfig.getDequeueThreadPoolSize());
+ assertEquals(50 * 1024 * 1024, customConfig.getMaxQueueBytesPerTopic());
+ assertEquals(32768, customConfig.getStrategyConfig().get(DeficitRoundRobinStrategy.CONFIG_KEY_QUANTUM));
+ }
+
+ @Test
+ public void testDRRDefaultQuantum() {
+ // Test that DRR uses default quantum when not configured
+ QueueingConfig noQuantumConfig = new QueueingConfig();
+ noQuantumConfig.setMaxQueueBytesPerTopic(1024 * 1024);
+ // Don't set quantum in strategyConfig
+
+ DeficitRoundRobinStrategy defaultStrategy = new DeficitRoundRobinStrategy();
+ defaultStrategy.init(noQuantumConfig);
+
+ try {
+ assertEquals(DeficitRoundRobinStrategy.DEFAULT_QUANTUM, defaultStrategy.getQuantum());
+ assertEquals(65536, defaultStrategy.getQuantum()); // 64KB
+ } finally {
+ defaultStrategy.shutdown();
+ }
+ }
+
+ @Test
+ public void testDRRCustomQuantum() {
+ // Test that DRR uses custom quantum when configured
+ QueueingConfig customQuantumConfig = new QueueingConfig();
+ customQuantumConfig.setMaxQueueBytesPerTopic(1024 * 1024);
+ customQuantumConfig.getStrategyConfig().put(DeficitRoundRobinStrategy.CONFIG_KEY_QUANTUM, 8192);
+
+ DeficitRoundRobinStrategy customStrategy = new DeficitRoundRobinStrategy();
+ customStrategy.init(customQuantumConfig);
+
+ try {
+ assertEquals(8192, customStrategy.getQuantum());
+ } finally {
+ customStrategy.shutdown();
+ }
+ }
+
+ @Test
+ public void testBytesDecrementOnDequeue() {
+ // Verify that bytes are correctly decremented when requests are dequeued
+ String topic = "test-topic";
+
+ assertTrue(strategy.enqueue(createRequest(topic, 100)));
+ assertTrue(strategy.enqueue(createRequest(topic, 200)));
+ assertTrue(strategy.enqueue(createRequest(topic, 300)));
+
+ assertEquals(600, strategy.getPendingBytes(topic));
+ assertEquals(3, strategy.getPendingCount(topic));
+
+ Set<String> topics = new HashSet<>();
+ topics.add(topic);
+
+ // Dequeue first request
+ QueuedRequest req1 = strategy.dequeue(topics);
+ assertNotNull(req1);
+ assertEquals(100, req1.getSize());
+ assertEquals(500, strategy.getPendingBytes(topic));
+ assertEquals(2, strategy.getPendingCount(topic));
+
+ // Dequeue second request
+ QueuedRequest req2 = strategy.dequeue(topics);
+ assertNotNull(req2);
+ assertEquals(200, req2.getSize());
+ assertEquals(300, strategy.getPendingBytes(topic));
+ assertEquals(1, strategy.getPendingCount(topic));
+
+ // Dequeue third request
+ QueuedRequest req3 = strategy.dequeue(topics);
+ assertNotNull(req3);
+ assertEquals(300, req3.getSize());
+ assertEquals(0, strategy.getPendingBytes(topic));
+ assertEquals(0, strategy.getPendingCount(topic));
+ }
+
+ // ===========================================
+ // ByteBuf Retain/Release Tests
+ // ===========================================
+
+ @Test
+ public void testByteBufRetainReleaseOnEnqueueDequeue() {
+ // This test verifies the ByteBuf retain/release pattern that prevents
+ // "refCnt: 0" errors when requests are queued.
+ //
+ // The flow being tested:
+ // 1. Request with ByteBuf is created (refCnt = 1)
+ // 2. Before enqueueing, retain() is called (refCnt = 2)
+ // 3. Original holder releases (simulating MemqRequestDecoder finally block) (refCnt = 1)
+ // 4. After dequeue and processing, release() is called (refCnt = 0, buffer recycled)
+
+ String topic = "test-topic";
+
+ // Create a request with a pooled buffer
+ ByteBuf pooledBuffer = PooledByteBufAllocator.DEFAULT.buffer(100);
+ pooledBuffer.writeBytes(new byte[100]);
+ assertEquals(1, pooledBuffer.refCnt());
+
+ WriteRequestPacket writePacket = new WriteRequestPacket();
+ writePacket.setTopicName(topic);
+ writePacket.setData(pooledBuffer);
+
+ // Simulate what QueuedPacketSwitchingHandler.handle() does:
+ // Retain before enqueueing
+ writePacket.getData().retain();
+ assertEquals(2, pooledBuffer.refCnt());
+
+ QueuedRequest request = new QueuedRequest(null, new RequestPacket(), writePacket);
+ assertTrue(strategy.enqueue(request));
+
+ // Simulate MemqRequestDecoder finally block releasing the original buffer
+ pooledBuffer.release();
+ assertEquals(1, pooledBuffer.refCnt()); // Still alive due to retain()
+
+ // Dequeue the request
+ Set<String> topics = new HashSet<>();
+ topics.add(topic);
+ QueuedRequest dequeued = strategy.dequeue(topics);
+ assertNotNull(dequeued);
+
+ // Buffer should still be usable
+ assertEquals(1, dequeued.getWritePacket().getData().refCnt());
+ assertTrue(dequeued.getWritePacket().getData().isReadable());
+
+ // Simulate what processQueuedRequest finally block does:
+ // Release after processing
+ dequeued.getWritePacket().getData().release();
+ assertEquals(0, pooledBuffer.refCnt()); // Now fully released
+ }
+
+ @Test
+ public void testByteBufReleaseOnEnqueueFailure() {
+ // Test that ByteBuf is properly released when enqueue fails (queue full)
+
+ QueueingConfig smallConfig = new QueueingConfig();
+ smallConfig.setMaxQueueBytesPerTopic(100); // Very small
+ smallConfig.getStrategyConfig().put(DeficitRoundRobinStrategy.CONFIG_KEY_QUANTUM, 1024);
+
+ DeficitRoundRobinStrategy smallStrategy = new DeficitRoundRobinStrategy();
+ smallStrategy.init(smallConfig);
+
+ try {
+ String topic = "test-topic";
+
+ // Fill the queue
+ ByteBuf buf1 = PooledByteBufAllocator.DEFAULT.buffer(100);
+ buf1.writeBytes(new byte[100]);
+ WriteRequestPacket writePacket1 = new WriteRequestPacket();
+ writePacket1.setTopicName(topic);
+ writePacket1.setData(buf1);
+
+ // Retain before enqueue (simulating handle())
+ buf1.retain();
+ assertEquals(2, buf1.refCnt());
+
+ assertTrue(smallStrategy.enqueue(new QueuedRequest(null, new RequestPacket(), writePacket1)));
+
+ // Simulate decoder finally block
+ buf1.release();
+ assertEquals(1, buf1.refCnt());
+
+ // Now try to enqueue another request that will fail
+ ByteBuf buf2 = PooledByteBufAllocator.DEFAULT.buffer(100);
+ buf2.writeBytes(new byte[100]);
+ WriteRequestPacket writePacket2 = new WriteRequestPacket();
+ writePacket2.setTopicName(topic);
+ writePacket2.setData(buf2);
+
+ // Retain before enqueue attempt
+ buf2.retain();
+ assertEquals(2, buf2.refCnt());
+
+ // Enqueue should fail
+ boolean enqueued = smallStrategy.enqueue(new QueuedRequest(null, new RequestPacket(), writePacket2));
+ assertFalse(enqueued);
+
+ // On failure, QueuedPacketSwitchingHandler releases the buffer it retained
+ if (!enqueued) {
+ buf2.release(); // Simulating the release on failure in handle()
+ }
+ assertEquals(1, buf2.refCnt());
+
+ // Simulate decoder finally block
+ buf2.release();
+ assertEquals(0, buf2.refCnt()); // Properly released
+
+ // Clean up the first buffer
+ Set<String> topics = new HashSet<>();
+ topics.add(topic);
+ QueuedRequest dequeued = smallStrategy.dequeue(topics);
+ assertNotNull(dequeued);
+ dequeued.getWritePacket().getData().release();
+ assertEquals(0, buf1.refCnt());
+
+ } finally {
+ smallStrategy.shutdown();
+ }
+ }
+
+ @Test
+ public void testByteBufRemainsReadableWhileQueued() {
+ // Test that the ByteBuf data remains readable while the request is queued
+
+ String topic = "test-topic";
+ byte[] testData = "Hello, World!".getBytes();
+
+ ByteBuf buffer = PooledByteBufAllocator.DEFAULT.buffer(testData.length);
+ buffer.writeBytes(testData);
+
+ WriteRequestPacket writePacket = new WriteRequestPacket();
+ writePacket.setTopicName(topic);
+ writePacket.setData(buffer);
+
+ // Retain before enqueueing
+ buffer.retain();
+ assertEquals(2, buffer.refCnt());
+
+ QueuedRequest request = new QueuedRequest(null, new RequestPacket(), writePacket);
+ assertTrue(strategy.enqueue(request));
+
+ // Simulate decoder releasing
+ buffer.release();
+ assertEquals(1, buffer.refCnt());
+
+ // Dequeue
+ Set<String> topics = new HashSet<>();
+ topics.add(topic);
+ QueuedRequest dequeued = strategy.dequeue(topics);
+ assertNotNull(dequeued);
+
+ // Verify data is still readable
+ ByteBuf dequeuedData = dequeued.getWritePacket().getData();
+ assertEquals(testData.length, dequeuedData.readableBytes());
+
+ byte[] readData = new byte[testData.length];
+ dequeuedData.readBytes(readData);
+
+ for (int i = 0; i < testData.length; i++) {
+ assertEquals(testData[i], readData[i]);
+ }
+
+ // Clean up
+ dequeuedData.release();
+ assertEquals(0, buffer.refCnt());
+ }
+}
diff --git a/pom.xml b/pom.xml
index aee2a84..e4f8687 100644
--- a/pom.xml
+++ b/pom.xml
@@ -3,7 +3,7 @@
   <modelVersion>4.0.0</modelVersion>
   <groupId>com.pinterest.memq</groupId>
   <artifactId>memq-parent</artifactId>
-  <version>1.0.1</version>
+  <version>1.0.2-SNAPSHOT</version>
   <packaging>pom</packaging>
   <name>memq-parent</name>
   <description>Hyperscale PubSub System</description>