From 04761a61a29201acd4b06405fff58dde45baeb19 Mon Sep 17 00:00:00 2001 From: Jeff Xiang Date: Mon, 5 Jan 2026 12:41:26 -0500 Subject: [PATCH 1/7] Emit realtime read throughput metrics from BrokerTrafficShapingHandler --- .../pinterest/memq/core/rpc/BrokerTrafficShapingHandler.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/memq/src/main/java/com/pinterest/memq/core/rpc/BrokerTrafficShapingHandler.java b/memq/src/main/java/com/pinterest/memq/core/rpc/BrokerTrafficShapingHandler.java index c47b467..3d95d9d 100644 --- a/memq/src/main/java/com/pinterest/memq/core/rpc/BrokerTrafficShapingHandler.java +++ b/memq/src/main/java/com/pinterest/memq/core/rpc/BrokerTrafficShapingHandler.java @@ -10,6 +10,7 @@ public class BrokerTrafficShapingHandler extends GlobalTrafficShapingHandler { public static final String READ_LIMIT_METRIC_NAME = "broker.traffic.read.limit"; + public static final String READ_THROUGHPUT_METRIC_NAME = "broker.traffic.read.throughput"; private static final Logger logger = Logger.getLogger(BrokerTrafficShapingHandler.class.getName()); private int metricsReportingIntervalSec = 60; // default 1 minute private final MetricRegistry registry; @@ -64,6 +65,8 @@ public void startPeriodicMetricsReporting(ScheduledExecutorService executorServi */ public void reportMetrics() { long readLimit = this.getReadLimit(); + long lastReadThroughput = this.trafficCounter().lastReadThroughput(); registry.gauge(READ_LIMIT_METRIC_NAME, () -> () -> readLimit); + registry.gauge(READ_THROUGHPUT_METRIC_NAME, () -> () -> lastReadThroughput); } } From 796d19e5b91705ca11ec136a53c685a63958184a Mon Sep 17 00:00:00 2001 From: Jeff Xiang Date: Mon, 5 Jan 2026 12:42:47 -0500 Subject: [PATCH 2/7] Bump version to 1.0.2-SNAPSHOT --- deploy/pom.xml | 2 +- memq-client-all/pom.xml | 4 ++-- memq-client/pom.xml | 2 +- memq-commons/pom.xml | 2 +- memq-examples/pom.xml | 2 +- memq/pom.xml | 2 +- pom.xml | 2 +- 7 files changed, 8 insertions(+), 8 deletions(-) diff --git a/deploy/pom.xml b/deploy/pom.xml index 9d03b6e..f8972f4 100644 --- a/deploy/pom.xml +++ b/deploy/pom.xml @@ -5,7 +5,7 @@ com.pinterest.memq memq-parent - 1.0.1 + 1.0.2-SNAPSHOT ../pom.xml memq-deploy diff --git a/memq-client-all/pom.xml b/memq-client-all/pom.xml index c7044b9..425ac39 100644 --- a/memq-client-all/pom.xml +++ b/memq-client-all/pom.xml @@ -6,7 +6,7 @@ com.pinterest.memq memq-parent - 1.0.1 + 1.0.2-SNAPSHOT ../pom.xml memq-client-all @@ -48,7 +48,7 @@ com.pinterest.memq memq-client - 1.0.1 + 1.0.2-SNAPSHOT diff --git a/memq-client/pom.xml b/memq-client/pom.xml index a737a94..9882ca1 100644 --- a/memq-client/pom.xml +++ b/memq-client/pom.xml @@ -4,7 +4,7 @@ com.pinterest.memq memq-parent - 1.0.1 + 1.0.2-SNAPSHOT ../pom.xml memq-client diff --git a/memq-commons/pom.xml b/memq-commons/pom.xml index 24391fe..a0e21de 100644 --- a/memq-commons/pom.xml +++ b/memq-commons/pom.xml @@ -5,7 +5,7 @@ com.pinterest.memq memq-parent - 1.0.1 + 1.0.2-SNAPSHOT ../pom.xml memq-commons diff --git a/memq-examples/pom.xml b/memq-examples/pom.xml index d051e48..4baba57 100644 --- a/memq-examples/pom.xml +++ b/memq-examples/pom.xml @@ -5,7 +5,7 @@ com.pinterest.memq memq-parent - 1.0.1 + 1.0.2-SNAPSHOT ../pom.xml memq-examples diff --git a/memq/pom.xml b/memq/pom.xml index 58f7e46..99084c9 100644 --- a/memq/pom.xml +++ b/memq/pom.xml @@ -6,7 +6,7 @@ com.pinterest.memq memq-parent - 1.0.1 + 1.0.2-SNAPSHOT ../pom.xml memq diff --git a/pom.xml b/pom.xml index aee2a84..e4f8687 100644 --- a/pom.xml +++ b/pom.xml @@ -3,7 +3,7 @@ 4.0.0 com.pinterest.memq memq-parent - 1.0.1 + 1.0.2-SNAPSHOT pom memq-parent Hyperscale PubSub System From 7ba5f95926d80351b594395a11378308c4a3c372 Mon Sep 17 00:00:00 2001 From: Jeff Xiang Date: Mon, 5 Jan 2026 13:20:29 -0500 Subject: [PATCH 3/7] Emit logs for traffic counter --- .../com/pinterest/memq/core/rpc/BrokerTrafficShapingHandler.java | 1 + 1 file changed, 1 insertion(+) diff --git a/memq/src/main/java/com/pinterest/memq/core/rpc/BrokerTrafficShapingHandler.java b/memq/src/main/java/com/pinterest/memq/core/rpc/BrokerTrafficShapingHandler.java index 3d95d9d..8e897fb 100644 --- a/memq/src/main/java/com/pinterest/memq/core/rpc/BrokerTrafficShapingHandler.java +++ b/memq/src/main/java/com/pinterest/memq/core/rpc/BrokerTrafficShapingHandler.java @@ -68,5 +68,6 @@ public void reportMetrics() { long lastReadThroughput = this.trafficCounter().lastReadThroughput(); registry.gauge(READ_LIMIT_METRIC_NAME, () -> () -> readLimit); registry.gauge(READ_THROUGHPUT_METRIC_NAME, () -> () -> lastReadThroughput); + logger.info("Latest traffic metrics: " + this.trafficCounter()); } } From 408be3271ceda0f5f6cbd92e2a1b717bbddd3d9e Mon Sep 17 00:00:00 2001 From: Jeff Xiang Date: Mon, 5 Jan 2026 15:21:04 -0500 Subject: [PATCH 4/7] Call lastReadThroughput method in gauge lambda supplier --- .../pinterest/memq/core/rpc/BrokerTrafficShapingHandler.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/memq/src/main/java/com/pinterest/memq/core/rpc/BrokerTrafficShapingHandler.java b/memq/src/main/java/com/pinterest/memq/core/rpc/BrokerTrafficShapingHandler.java index 8e897fb..8f3ceb2 100644 --- a/memq/src/main/java/com/pinterest/memq/core/rpc/BrokerTrafficShapingHandler.java +++ b/memq/src/main/java/com/pinterest/memq/core/rpc/BrokerTrafficShapingHandler.java @@ -65,9 +65,8 @@ public void startPeriodicMetricsReporting(ScheduledExecutorService executorServi */ public void reportMetrics() { long readLimit = this.getReadLimit(); - long lastReadThroughput = this.trafficCounter().lastReadThroughput(); registry.gauge(READ_LIMIT_METRIC_NAME, () -> () -> readLimit); - registry.gauge(READ_THROUGHPUT_METRIC_NAME, () -> () -> lastReadThroughput); + registry.gauge(READ_THROUGHPUT_METRIC_NAME, () -> () -> this.trafficCounter().lastReadThroughput()); logger.info("Latest traffic metrics: " + this.trafficCounter()); } } From aaf2d215a1f411f5025167949394c7001599d39c Mon Sep 17 00:00:00 2001 From: Jeff Xiang Date: Tue, 13 Jan 2026 18:24:37 -0500 Subject: [PATCH 5/7] Simple implementation for DRR fair-queueing on broker --- .../memq/core/config/NettyServerConfig.java | 10 + .../memq/core/config/QueueingConfig.java | 95 ++ .../core/rpc/BrokerTrafficShapingHandler.java | 3 - .../memq/core/rpc/MemqNettyServer.java | 29 +- .../memq/core/rpc/MemqRequestDecoder.java | 22 +- .../rpc/queue/DeficitRoundRobinStrategy.java | 315 ++++++ .../queue/QueuedPacketSwitchingHandler.java | 298 ++++++ .../memq/core/rpc/queue/QueuedRequest.java | 73 ++ .../memq/core/rpc/queue/QueueingStrategy.java | 91 ++ .../TestMemqClientServerIntegration.java | 3 +- .../TestQueuedPacketSwitchingHandler.java | 907 ++++++++++++++++++ 11 files changed, 1831 insertions(+), 15 deletions(-) create mode 100644 memq/src/main/java/com/pinterest/memq/core/config/QueueingConfig.java create mode 100644 memq/src/main/java/com/pinterest/memq/core/rpc/queue/DeficitRoundRobinStrategy.java create mode 100644 memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueuedPacketSwitchingHandler.java create mode 100644 memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueuedRequest.java create mode 100644 memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueueingStrategy.java create mode 100644 memq/src/test/java/com/pinterest/memq/core/rpc/queue/TestQueuedPacketSwitchingHandler.java diff --git a/memq/src/main/java/com/pinterest/memq/core/config/NettyServerConfig.java b/memq/src/main/java/com/pinterest/memq/core/config/NettyServerConfig.java index e462a29..c054fd4 100644 --- a/memq/src/main/java/com/pinterest/memq/core/config/NettyServerConfig.java +++ b/memq/src/main/java/com/pinterest/memq/core/config/NettyServerConfig.java @@ -28,6 +28,8 @@ public class NettyServerConfig { private int brokerInputTrafficShapingMetricsReportIntervalSec = 60; // 1 minute by default // SSL private SSLConfig sslConfig; + // Fair queueing configuration + private QueueingConfig queueingConfig = null; public int getBrokerInputTrafficShapingMetricsReportIntervalSec() { return brokerInputTrafficShapingMetricsReportIntervalSec; @@ -94,4 +96,12 @@ public void setEnableEpoll(boolean enableEpoll) { this.enableEpoll = enableEpoll; } + public QueueingConfig getQueueingConfig() { + return queueingConfig; + } + + public void setQueueingConfig(QueueingConfig queueingConfig) { + this.queueingConfig = queueingConfig; + } + } diff --git a/memq/src/main/java/com/pinterest/memq/core/config/QueueingConfig.java b/memq/src/main/java/com/pinterest/memq/core/config/QueueingConfig.java new file mode 100644 index 0000000..dd6885f --- /dev/null +++ b/memq/src/main/java/com/pinterest/memq/core/config/QueueingConfig.java @@ -0,0 +1,95 @@ +/** + * Copyright 2022 Pinterest, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.pinterest.memq.core.config; + +import com.pinterest.memq.core.rpc.queue.DeficitRoundRobinStrategy; + +import java.util.HashMap; +import java.util.Map; + +/** + * Configuration for fair-queueing mechanism in packet processing. + */ +public class QueueingConfig { + + private boolean enabled = false; + + /** + * The fully qualified class name of the QueueingStrategy implementation to use. + * Default is DeficitRoundRobinStrategy. + */ + private String strategyClass = DeficitRoundRobinStrategy.class.getName(); + + /** + * Number of threads in the dequeue thread pool. + * Each thread processes a subset of the topic queues. + */ + private int dequeueThreadPoolSize = 4; + + /** + * Maximum bytes of pending requests per topic queue. + * If a queue exceeds this limit, requests will be rejected. + * Default is 100MB per topic. + */ + private long maxQueueBytesPerTopic = 100 * 1024 * 1024; // 100MB default + + /** + * Strategy-specific configuration options. + * Keys and values depend on the strategy implementation. + * For example, DeficitRoundRobinStrategy uses "quantum" key. + */ + private Map strategyConfig = new HashMap<>(); + + public boolean isEnabled() { + return enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + + public String getStrategyClass() { + return strategyClass; + } + + public void setStrategyClass(String strategyClass) { + this.strategyClass = strategyClass; + } + + public int getDequeueThreadPoolSize() { + return dequeueThreadPoolSize; + } + + public void setDequeueThreadPoolSize(int dequeueThreadPoolSize) { + this.dequeueThreadPoolSize = dequeueThreadPoolSize; + } + + public long getMaxQueueBytesPerTopic() { + return maxQueueBytesPerTopic; + } + + public void setMaxQueueBytesPerTopic(long maxQueueBytesPerTopic) { + this.maxQueueBytesPerTopic = maxQueueBytesPerTopic; + } + + public Map getStrategyConfig() { + return strategyConfig; + } + + public void setStrategyConfig(Map strategyConfig) { + this.strategyConfig = strategyConfig; + } +} diff --git a/memq/src/main/java/com/pinterest/memq/core/rpc/BrokerTrafficShapingHandler.java b/memq/src/main/java/com/pinterest/memq/core/rpc/BrokerTrafficShapingHandler.java index 8f3ceb2..c47b467 100644 --- a/memq/src/main/java/com/pinterest/memq/core/rpc/BrokerTrafficShapingHandler.java +++ b/memq/src/main/java/com/pinterest/memq/core/rpc/BrokerTrafficShapingHandler.java @@ -10,7 +10,6 @@ public class BrokerTrafficShapingHandler extends GlobalTrafficShapingHandler { public static final String READ_LIMIT_METRIC_NAME = "broker.traffic.read.limit"; - public static final String READ_THROUGHPUT_METRIC_NAME = "broker.traffic.read.throughput"; private static final Logger logger = Logger.getLogger(BrokerTrafficShapingHandler.class.getName()); private int metricsReportingIntervalSec = 60; // default 1 minute private final MetricRegistry registry; @@ -66,7 +65,5 @@ public void startPeriodicMetricsReporting(ScheduledExecutorService executorServi public void reportMetrics() { long readLimit = this.getReadLimit(); registry.gauge(READ_LIMIT_METRIC_NAME, () -> () -> readLimit); - registry.gauge(READ_THROUGHPUT_METRIC_NAME, () -> () -> this.trafficCounter().lastReadThroughput()); - logger.info("Latest traffic metrics: " + this.trafficCounter()); } } diff --git a/memq/src/main/java/com/pinterest/memq/core/rpc/MemqNettyServer.java b/memq/src/main/java/com/pinterest/memq/core/rpc/MemqNettyServer.java index c400616..747d0a1 100644 --- a/memq/src/main/java/com/pinterest/memq/core/rpc/MemqNettyServer.java +++ b/memq/src/main/java/com/pinterest/memq/core/rpc/MemqNettyServer.java @@ -41,6 +41,8 @@ import com.pinterest.memq.core.config.AuthorizerConfig; import com.pinterest.memq.core.config.MemqConfig; import com.pinterest.memq.core.config.NettyServerConfig; +import com.pinterest.memq.core.config.QueueingConfig; +import com.pinterest.memq.core.rpc.queue.QueuedPacketSwitchingHandler; import com.pinterest.memq.core.security.Authorizer; import com.pinterest.memq.core.utils.DaemonThreadFactory; import com.pinterest.memq.core.utils.MemqUtils; @@ -137,6 +139,11 @@ public void initialize() throws Exception { trafficShapingHandler.startPeriodicMetricsReporting(childGroup); } + // Create the packet switch handler ONCE and share across all connections. + // This is critical because QueuedPacketSwitchingHandler creates its own thread pool. + final PacketSwitchingHandler packetSwitchHandler = createPacketSwitchHandler( + nettyServerConfig.getQueueingConfig(), authorizer, registry); + if (useEpoll) { serverBootstrap.channel(EpollServerSocketChannel.class); } else { @@ -170,7 +177,7 @@ protected void initChannel(SocketChannel channel) throws Exception { pipeline.addLast(new LengthFieldBasedFrameDecoder(ByteOrder.BIG_ENDIAN, nettyServerConfig.getMaxFrameByteLength(), 0, Integer.BYTES, 0, 0, false)); pipeline.addLast(new MemqResponseEncoder(registry)); - pipeline.addLast(new MemqRequestDecoder(memqManager, memqGovernor, authorizer, registry)); + pipeline.addLast(new MemqRequestDecoder(packetSwitchHandler, registry)); } }); @@ -229,6 +236,26 @@ private Authorizer enableAuthenticationAuthorizationAuditing(MemqConfig configur return null; } + /** + * Create the packet switch handler. If queueing is enabled, creates a + * QueuedPacketSwitchingHandler with its own thread pool; otherwise creates + * a regular PacketSwitchingHandler. + * + * This handler is created ONCE and shared across all connections to avoid + * creating a new thread pool per connection. + */ + private PacketSwitchingHandler createPacketSwitchHandler(QueueingConfig queueingConfig, + Authorizer authorizer, + MetricRegistry registry) throws Exception { + if (queueingConfig != null && queueingConfig.isEnabled()) { + logger.info("Creating QueuedPacketSwitchingHandler with fair queueing enabled"); + return new QueuedPacketSwitchingHandler( + memqManager, memqGovernor, authorizer, registry, queueingConfig); + } else { + return new PacketSwitchingHandler(memqManager, memqGovernor, authorizer, registry); + } + } + private EventLoopGroup getEventLoopGroup(int nThreads) { if (useEpoll) { logger.info("Epoll is available and will be used"); diff --git a/memq/src/main/java/com/pinterest/memq/core/rpc/MemqRequestDecoder.java b/memq/src/main/java/com/pinterest/memq/core/rpc/MemqRequestDecoder.java index a86f7eb..a6a0781 100644 --- a/memq/src/main/java/com/pinterest/memq/core/rpc/MemqRequestDecoder.java +++ b/memq/src/main/java/com/pinterest/memq/core/rpc/MemqRequestDecoder.java @@ -32,9 +32,6 @@ import com.pinterest.memq.commons.protocol.RequestPacket; import com.pinterest.memq.commons.protocol.ResponseCodes; import com.pinterest.memq.commons.protocol.ResponsePacket; -import com.pinterest.memq.core.MemqManager; -import com.pinterest.memq.core.clustering.MemqGovernor; -import com.pinterest.memq.core.security.Authorizer; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; @@ -45,15 +42,20 @@ public class MemqRequestDecoder extends ChannelInboundHandlerAdapter { private static final Logger logger = Logger .getLogger(MemqRequestDecoder.class.getCanonicalName()); - private Counter errorCounter; - private PacketSwitchingHandler packetSwitchHandler; + private final Counter errorCounter; + private final PacketSwitchingHandler packetSwitchHandler; - public MemqRequestDecoder(MemqManager mgr, - MemqGovernor governor, - Authorizer authorizer, + /** + * Create a MemqRequestDecoder with a shared PacketSwitchingHandler. + * This constructor should be used to avoid creating a new handler per connection. + * + * @param packetSwitchHandler the shared handler instance + * @param registry the metrics registry + */ + public MemqRequestDecoder(PacketSwitchingHandler packetSwitchHandler, MetricRegistry registry) { - errorCounter = registry.counter("request.error"); - packetSwitchHandler = new PacketSwitchingHandler(mgr, governor, authorizer, registry); + this.errorCounter = registry.counter("request.error"); + this.packetSwitchHandler = packetSwitchHandler; } @Override diff --git a/memq/src/main/java/com/pinterest/memq/core/rpc/queue/DeficitRoundRobinStrategy.java b/memq/src/main/java/com/pinterest/memq/core/rpc/queue/DeficitRoundRobinStrategy.java new file mode 100644 index 0000000..94a9c74 --- /dev/null +++ b/memq/src/main/java/com/pinterest/memq/core/rpc/queue/DeficitRoundRobinStrategy.java @@ -0,0 +1,315 @@ +/** + * Copyright 2022 Pinterest, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.pinterest.memq.core.rpc.queue; + +import com.pinterest.memq.core.config.QueueingConfig; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.logging.Logger; + +/** + * Deficit Round Robin (DRR) implementation of QueueingStrategy. + * + * DRR is a fair scheduling algorithm that provides weighted fair queuing. + * Each queue maintains a deficit counter. When a queue is visited: + * - Its deficit counter is incremented by a quantum value + * - Requests are dequeued as long as their size is less than or equal to the deficit counter + * - The request's size is subtracted from the deficit counter + * + * This ensures fair bandwidth allocation across all topics while being + * work-conserving (idle queues don't waste capacity). + * + * Strategy-specific configuration (in strategyConfig map): + * - "quantum": int - bytes per round per queue (default: 64KB) + */ +public class DeficitRoundRobinStrategy extends QueueingStrategy { + + private static final Logger logger = Logger.getLogger(DeficitRoundRobinStrategy.class.getName()); + + /** + * Configuration key for the quantum value in strategyConfig. + */ + public static final String CONFIG_KEY_QUANTUM = "quantum"; + + /** + * Default quantum size in bytes (64KB). + * This determines how many bytes worth of requests can be processed + * per round for each queue. + */ + public static final int DEFAULT_QUANTUM = 65536; + + /** + * Per-topic queue holding pending requests. + */ + private final ConcurrentHashMap> queues = new ConcurrentHashMap<>(); + + /** + * Current byte count for each topic queue. + */ + private final ConcurrentHashMap queueBytes = new ConcurrentHashMap<>(); + + /** + * Deficit counter for each topic (used by DRR algorithm). + */ + private final ConcurrentHashMap deficitCounters = new ConcurrentHashMap<>(); + + /** + * Round-robin state for each dequeue thread - tracks current position in topic list. + */ + private final ConcurrentHashMap> threadIterators = new ConcurrentHashMap<>(); + + /** + * Cached list of topics for round-robin iteration. + */ + private final ConcurrentHashMap> threadTopicLists = new ConcurrentHashMap<>(); + + private int quantum; + private long maxQueueBytes; + + @Override + public void init(QueueingConfig config) { + super.init(config); + this.quantum = getQuantumFromConfig(config); + this.maxQueueBytes = config.getMaxQueueBytesPerTopic(); + } + + /** + * Extract the quantum value from strategy config, or use default. + */ + private int getQuantumFromConfig(QueueingConfig config) { + Object quantumObj = config.getStrategyConfig().get(CONFIG_KEY_QUANTUM); + if (quantumObj != null) { + if (quantumObj instanceof Number) { + return ((Number) quantumObj).intValue(); + } + if (quantumObj instanceof String) { + try { + return Integer.parseInt((String) quantumObj); + } catch (NumberFormatException e) { + logger.warning("Invalid quantum value in config: " + quantumObj + ", using default"); + } + } + } + return DEFAULT_QUANTUM; + } + + /** + * Get the current quantum value being used. + * + * @return the quantum in bytes + */ + public int getQuantum() { + return quantum; + } + + @Override + public boolean enqueue(QueuedRequest request) { + String topicName = request.getTopicName(); + int requestSize = Math.max(1, request.getSize()); + + // Get or create queue and byte counter + ConcurrentLinkedQueue queue = queues.computeIfAbsent(topicName, + k -> new ConcurrentLinkedQueue<>()); + AtomicLong bytes = queueBytes.computeIfAbsent(topicName, k -> new AtomicLong(0)); + deficitCounters.computeIfAbsent(topicName, k -> new AtomicInteger(0)); + + // Check if adding this request would exceed the byte limit + // Use CAS loop for thread safety + while (true) { + long currentBytes = bytes.get(); + if (currentBytes + requestSize > maxQueueBytes) { + logger.warning("Queue full for topic: " + topicName + + ", current bytes: " + currentBytes + ", request size: " + requestSize + + ", max: " + maxQueueBytes + ", rejecting request"); + return false; + } + if (bytes.compareAndSet(currentBytes, currentBytes + requestSize)) { + queue.offer(request); + return true; + } + // CAS failed, retry + } + } + + @Override + public QueuedRequest dequeue(Set assignedTopics) { + if (assignedTopics.isEmpty()) { + return null; + } + + long threadId = Thread.currentThread().getId(); + + // Get or create the topic list for this thread + List topicList = threadTopicLists.computeIfAbsent(threadId, k -> new ArrayList<>()); + + // Refresh the topic list if needed (topics may have been added/removed) + refreshTopicListIfNeeded(topicList, assignedTopics); + + if (topicList.isEmpty()) { + return null; + } + + // Get or create iterator for this thread + Iterator iterator = threadIterators.get(threadId); + if (iterator == null || !iterator.hasNext()) { + iterator = topicList.iterator(); + threadIterators.put(threadId, iterator); + } + + // Try each topic in round-robin fashion + int topicsChecked = 0; + int totalTopics = topicList.size(); + + while (topicsChecked < totalTopics) { + if (!iterator.hasNext()) { + iterator = topicList.iterator(); + threadIterators.put(threadId, iterator); + } + + String topicName = iterator.next(); + topicsChecked++; + + ConcurrentLinkedQueue queue = queues.get(topicName); + if (queue == null || queue.isEmpty()) { + // Reset deficit when queue is empty (work-conserving) + AtomicInteger deficit = deficitCounters.get(topicName); + if (deficit != null) { + deficit.set(0); + } + continue; + } + + AtomicInteger deficit = deficitCounters.get(topicName); + if (deficit == null) { + continue; + } + + // Peek at the next request + QueuedRequest request = queue.peek(); + if (request == null) { + continue; + } + + int requestSize = Math.max(1, request.getSize()); // Minimum size of 1 to prevent starvation + int currentDeficit = deficit.get(); + + // Add quantum to deficit + if (currentDeficit < requestSize) { + currentDeficit = deficit.addAndGet(quantum); + } + + // Check if we can service this request + if (currentDeficit >= requestSize) { + // Try to poll the request + request = queue.poll(); + if (request != null) { + // Subtract request size from deficit and update byte counter + deficit.addAndGet(-requestSize); + AtomicLong bytes = queueBytes.get(topicName); + if (bytes != null) { + bytes.addAndGet(-requestSize); + } + return request; + } + } + } + + return null; + } + + /** + * Refresh the topic list if the assigned topics have changed. + */ + private void refreshTopicListIfNeeded(List topicList, Set assignedTopics) { + // Check if we need to refresh + Set currentTopics = new HashSet<>(topicList); + Set activeAssigned = new HashSet<>(); + + for (String topic : assignedTopics) { + if (queues.containsKey(topic)) { + activeAssigned.add(topic); + } + } + + if (!currentTopics.equals(activeAssigned)) { + topicList.clear(); + topicList.addAll(activeAssigned); + // Reset iterator since list changed + threadIterators.remove(Thread.currentThread().getId()); + } + } + + @Override + public long getPendingBytes() { + long total = 0; + for (AtomicLong bytes : queueBytes.values()) { + total += bytes.get(); + } + return total; + } + + @Override + public long getPendingBytes(String topicName) { + AtomicLong bytes = queueBytes.get(topicName); + return bytes != null ? bytes.get() : 0; + } + + /** + * Get the current number of pending requests across all queues. + * + * @return total number of pending requests + */ + public int getPendingCount() { + int total = 0; + for (ConcurrentLinkedQueue queue : queues.values()) { + total += queue.size(); + } + return total; + } + + /** + * Get the current number of pending requests for a specific topic. + * + * @param topicName the topic name + * @return number of pending requests for the topic + */ + public int getPendingCount(String topicName) { + ConcurrentLinkedQueue queue = queues.get(topicName); + return queue != null ? queue.size() : 0; + } + + @Override + public Set getActiveTopics() { + return new HashSet<>(queues.keySet()); + } + + @Override + public void shutdown() { + queues.clear(); + queueBytes.clear(); + deficitCounters.clear(); + threadIterators.clear(); + threadTopicLists.clear(); + } +} diff --git a/memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueuedPacketSwitchingHandler.java b/memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueuedPacketSwitchingHandler.java new file mode 100644 index 0000000..5cbd6fb --- /dev/null +++ b/memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueuedPacketSwitchingHandler.java @@ -0,0 +1,298 @@ +/** + * Copyright 2022 Pinterest, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.pinterest.memq.core.rpc.queue; + +import java.security.Principal; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.logging.Level; +import java.util.logging.Logger; + +import javax.ws.rs.ServiceUnavailableException; + +import com.codahale.metrics.Gauge; +import com.codahale.metrics.MetricRegistry; +import com.pinterest.memq.commons.protocol.RequestPacket; +import com.pinterest.memq.commons.protocol.RequestType; +import com.pinterest.memq.commons.protocol.WriteRequestPacket; +import com.pinterest.memq.core.MemqManager; +import com.pinterest.memq.core.clustering.MemqGovernor; +import com.pinterest.memq.core.config.QueueingConfig; +import com.pinterest.memq.core.rpc.PacketSwitchingHandler; +import com.pinterest.memq.core.security.Authorizer; + +import io.netty.channel.ChannelHandlerContext; + +/** + * A PacketSwitchingHandler that implements fair queueing for write requests. + * + * Write requests are enqueued into topic-specific queues and processed by + * a configurable thread pool using a fair scheduling algorithm (e.g., DRR). + * + * Read and metadata requests are processed directly without queueing. + */ +public class QueuedPacketSwitchingHandler extends PacketSwitchingHandler { + + private static final Logger logger = Logger.getLogger(QueuedPacketSwitchingHandler.class.getName()); + + private final QueueingStrategy strategy; + private final ExecutorService dequeueExecutor; + private final AtomicBoolean running = new AtomicBoolean(true); + private final ConcurrentHashMap perTopicGaugeRegistered = new ConcurrentHashMap<>(); + + /** + * Number of dequeue threads. + * Topics are distributed across threads using consistent hashing on topic name. + */ + private final int numDequeueThreads; + + public QueuedPacketSwitchingHandler(MemqManager mgr, + MemqGovernor governor, + Authorizer authorizer, + MetricRegistry registry, + QueueingConfig queueingConfig) throws Exception { + super(mgr, governor, authorizer, registry); + this.numDequeueThreads = queueingConfig.getDequeueThreadPoolSize(); + + // Initialize the queueing strategy first + this.strategy = createStrategy(queueingConfig); + + // Initialize the dequeue thread pool with named daemon threads + this.dequeueExecutor = Executors.newFixedThreadPool( + numDequeueThreads, + new ThreadFactory() { + private final AtomicInteger threadNumber = new AtomicInteger(1); + @Override + public Thread newThread(Runnable r) { + Thread t = new Thread(r, "memq-dequeue-" + threadNumber.getAndIncrement()); + t.setDaemon(true); + return t; + } + }); + + // Start the dequeue worker threads + for (int i = 0; i < numDequeueThreads; i++) { + final int threadIndex = i; + dequeueExecutor.submit(() -> dequeueLoop(threadIndex)); + } + + logger.info("QueuedPacketSwitchingHandler initialized with strategy: " + + queueingConfig.getStrategyClass() + ", threads: " + numDequeueThreads); + } + + /** + * Create a QueueingStrategy instance from configuration. + */ + @SuppressWarnings("unchecked") + private QueueingStrategy createStrategy(QueueingConfig config) throws Exception { + Class strategyClass = + (Class) Class.forName(config.getStrategyClass()); + QueueingStrategy strategy = strategyClass.getDeclaredConstructor().newInstance(); + strategy.init(config); + return strategy; + } + + @Override + public void handle(ChannelHandlerContext ctx, + RequestPacket requestPacket, + Principal principal, + String clientAddress) throws Exception { + // Only queue WRITE requests + if (requestPacket.getRequestType() == RequestType.WRITE) { + WriteRequestPacket writePacket = (WriteRequestPacket) requestPacket.getPayload(); + registerPerTopicMetricsIfNeeded(writePacket.getTopicName()); + + // Retain the ByteBuf since it will be released by MemqRequestDecoder's finally block + // but we need it to stay alive until the request is dequeued and processed. + // The buffer will be released in processQueuedRequest after processing. + writePacket.getData().retain(); + + // Create queued request + QueuedRequest queuedRequest = new QueuedRequest(ctx, requestPacket, writePacket); + + // Try to enqueue + boolean enqueued = strategy.enqueue(queuedRequest); + + if (enqueued) { + MetricRegistry topicRegistry = mgr.getRegistry().get(queuedRequest.getTopicName()); + if (topicRegistry != null) { + topicRegistry.counter("queue.enqueued.bytes").inc(queuedRequest.getSize()); + } + } else { + // Release the buffer since we retained it but failed to enqueue + writePacket.getData().release(); + MetricRegistry topicRegistry = mgr.getRegistry().get(queuedRequest.getTopicName()); + if (topicRegistry != null) { + topicRegistry.counter("queue.rejected.bytes").inc(queuedRequest.getSize()); + } + // TODO: we should return a response packet with an error code signalling congestion + throw new ServiceUnavailableException("Queue full for topic: " + writePacket.getTopicName()); + } + } else { + // Process non-write requests directly (metadata, read requests) + super.handle(ctx, requestPacket, principal, clientAddress); + } + } + + /** + * Main dequeue loop for each worker thread. + * Each thread is responsible for processing a subset of topics. + */ + private void dequeueLoop(int threadIndex) { + logger.info("Dequeue worker thread " + threadIndex + " started"); + + while (running.get()) { + try { + // Get topics assigned to this thread + Set assignedTopics = getAssignedTopics(threadIndex); + + // Try to dequeue a request + QueuedRequest request = strategy.dequeue(assignedTopics); + + if (request != null) { + registerPerTopicMetricsIfNeeded(request.getTopicName()); + long queueLatencyNanos = System.nanoTime() - request.getEnqueueTimeNanos(); + MetricRegistry topicRegistry = mgr.getRegistry().get(request.getTopicName()); + if (topicRegistry != null) { + topicRegistry.timer("queue.latency.nanos").update(queueLatencyNanos, TimeUnit.NANOSECONDS); + } + if (topicRegistry != null) { + topicRegistry.counter("queue.dequeued.bytes").inc(request.getSize()); + } + processQueuedRequest(request); + } else { + // No work available, sleep briefly to avoid spinning + // TODO: event-driven approach to avoid spinning + Thread.sleep(1); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + break; + } catch (Exception e) { + logger.log(Level.SEVERE, "Error in dequeue loop", e); + } + } + + logger.info("Dequeue worker thread " + threadIndex + " stopped"); + } + + /** + * Get the set of topics assigned to a specific thread index. + * Uses consistent hashing to distribute topics across threads. + */ + private Set getAssignedTopics(int threadIndex) { + Set assignedTopics = new HashSet<>(); + + for (String topic : strategy.getActiveTopics()) { + int assignedThread = Math.abs(topic.hashCode()) % numDequeueThreads; + if (assignedThread == threadIndex) { + assignedTopics.add(topic); + } + } + + return assignedTopics; + } + + /** + * Process a dequeued request by calling the parent's executeWriteRequest. + */ + private void processQueuedRequest(QueuedRequest request) { + try { + executeWriteRequest(request.getCtx(), request.getRequestPacket(), request.getWritePacket()); + } catch (Exception e) { + logger.log(Level.SEVERE, "Error processing queued request for topic: " + + request.getTopicName(), e); + // The channel context may need error handling + handleProcessingError(request, e); + } finally { + // Release the ByteBuf that was retained when the request was enqueued. + // This balances the retain() call in handle(). + try { + request.getWritePacket().getData().release(); + } catch (Exception e) { + logger.log(Level.WARNING, "Failed to release buffer for request", e); + } + } + } + + /** + * Handle errors that occur during request processing. + */ + private void handleProcessingError(QueuedRequest request, Exception e) { + try { + ChannelHandlerContext ctx = request.getCtx(); + if (ctx.channel().isActive()) { + // Error will be handled by the exception handler in the pipeline + ctx.fireExceptionCaught(e); + } + } catch (Exception ex) { + logger.log(Level.WARNING, "Failed to handle processing error", ex); + } + } + + private void registerPerTopicMetricsIfNeeded(String topicName) { + if (perTopicGaugeRegistered.putIfAbsent(topicName, Boolean.TRUE) != null) { + return; + } + MetricRegistry topicRegistry = mgr.getRegistry().get(topicName); + if (topicRegistry == null) { + return; + } + topicRegistry.gauge("queue.pending.bytes", + () -> (Gauge) () -> strategy.getPendingBytes(topicName)); + } + + /** + * Shutdown the handler and its thread pool. + */ + public void shutdown() { + running.set(false); + strategy.shutdown(); + + dequeueExecutor.shutdown(); + try { + if (!dequeueExecutor.awaitTermination(30, TimeUnit.SECONDS)) { + dequeueExecutor.shutdownNow(); + } + } catch (InterruptedException e) { + dequeueExecutor.shutdownNow(); + Thread.currentThread().interrupt(); + } + + logger.info("QueuedPacketSwitchingHandler shutdown complete"); + } + + /** + * Get the queueing strategy being used. + */ + public QueueingStrategy getStrategy() { + return strategy; + } + + /** + * Get the pending bytes across all queues. + */ + public long getPendingBytes() { + return strategy.getPendingBytes(); + } +} diff --git a/memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueuedRequest.java b/memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueuedRequest.java new file mode 100644 index 0000000..c24daa4 --- /dev/null +++ b/memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueuedRequest.java @@ -0,0 +1,73 @@ +/** + * Copyright 2022 Pinterest, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.pinterest.memq.core.rpc.queue; + +import com.pinterest.memq.commons.protocol.RequestPacket; +import com.pinterest.memq.commons.protocol.WriteRequestPacket; + +import io.netty.channel.ChannelHandlerContext; + +/** + * Wrapper class that holds all the context needed to process a queued request. + */ +public class QueuedRequest { + + private final ChannelHandlerContext ctx; + private final RequestPacket requestPacket; + private final WriteRequestPacket writePacket; + private final String topicName; + private final int size; + private final long enqueueTimeNanos; + + public QueuedRequest(ChannelHandlerContext ctx, + RequestPacket requestPacket, + WriteRequestPacket writePacket) { + this.ctx = ctx; + this.requestPacket = requestPacket; + this.writePacket = writePacket; + this.topicName = writePacket.getTopicName(); + this.size = writePacket.getDataLength(); + this.enqueueTimeNanos = System.nanoTime(); + } + + public ChannelHandlerContext getCtx() { + return ctx; + } + + public RequestPacket getRequestPacket() { + return requestPacket; + } + + public WriteRequestPacket getWritePacket() { + return writePacket; + } + + public String getTopicName() { + return topicName; + } + + /** + * Returns the size of the request in bytes. + * Used by strategies like DRR that need to track byte counts. + */ + public int getSize() { + return size; + } + + public long getEnqueueTimeNanos() { + return enqueueTimeNanos; + } +} diff --git a/memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueueingStrategy.java b/memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueueingStrategy.java new file mode 100644 index 0000000..5f39861 --- /dev/null +++ b/memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueueingStrategy.java @@ -0,0 +1,91 @@ +/** + * Copyright 2022 Pinterest, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.pinterest.memq.core.rpc.queue; + +import com.pinterest.memq.core.config.QueueingConfig; + +import java.util.Set; + +/** + * Abstract base class for fair queueing strategies. + * + * Implementations of this class provide different algorithms for + * scheduling requests across multiple topic queues in a fair manner. + */ +public abstract class QueueingStrategy { + + protected QueueingConfig config; + + /** + * Initialize the strategy with configuration. + * Called once after instantiation. + * + * @param config the queueing configuration + */ + public void init(QueueingConfig config) { + this.config = config; + } + + /** + * Enqueue a request into the appropriate topic queue. + * + * @param request the queued request to enqueue + * @return true if the request was successfully enqueued, false if rejected + * (e.g., due to queue being full) + */ + public abstract boolean enqueue(QueuedRequest request); + + /** + * Dequeue the next request to be processed according to the strategy's + * fair scheduling algorithm. + * + * This method should only process queues that are assigned to the calling thread. + * + * @param assignedTopics the set of topic names that this thread is responsible for + * @return the next request to process, or null if no requests are available + */ + public abstract QueuedRequest dequeue(Set assignedTopics); + + /** + * Get the current pending bytes across all queues. + * + * @return total pending bytes + */ + public abstract long getPendingBytes(); + + /** + * Get the current pending bytes for a specific topic. + * + * @param topicName the topic name + * @return pending bytes for the topic + */ + public abstract long getPendingBytes(String topicName); + + /** + * Get all topic names that currently have queues (even if empty). + * + * @return set of topic names + */ + public abstract Set getActiveTopics(); + + /** + * Called when the strategy is being shut down. + * Implementations should clean up any resources. + */ + public void shutdown() { + // Default no-op implementation + } +} diff --git a/memq/src/test/java/com/pinterest/memq/core/integration/TestMemqClientServerIntegration.java b/memq/src/test/java/com/pinterest/memq/core/integration/TestMemqClientServerIntegration.java index e776eda..f4618b3 100644 --- a/memq/src/test/java/com/pinterest/memq/core/integration/TestMemqClientServerIntegration.java +++ b/memq/src/test/java/com/pinterest/memq/core/integration/TestMemqClientServerIntegration.java @@ -95,6 +95,7 @@ import com.pinterest.memq.core.rpc.MemqNettyServer; import com.pinterest.memq.core.rpc.MemqRequestDecoder; import com.pinterest.memq.core.rpc.MemqResponseEncoder; +import com.pinterest.memq.core.rpc.PacketSwitchingHandler; import com.pinterest.memq.core.rpc.TestAuditor; import com.pinterest.memq.core.utils.MiscUtils; import com.salesforce.kafka.test.junit4.SharedKafkaTestResource; @@ -141,7 +142,7 @@ public void testProducerAndServer() throws IOException { MetricRegistry registry = new MetricRegistry(); EmbeddedChannel ech = new EmbeddedChannel(new MemqResponseEncoder(registry), new LengthFieldBasedFrameDecoder(ByteOrder.BIG_ENDIAN, 2 * 1024 * 1024, 0, 4, 0, 0, false), - new MemqRequestDecoder(null, null, null, registry)); + new MemqRequestDecoder(new PacketSwitchingHandler(null, null, null, registry), registry)); ech.writeInbound(output); ech.checkException(); assertEquals(1, ech.outboundMessages().size()); diff --git a/memq/src/test/java/com/pinterest/memq/core/rpc/queue/TestQueuedPacketSwitchingHandler.java b/memq/src/test/java/com/pinterest/memq/core/rpc/queue/TestQueuedPacketSwitchingHandler.java new file mode 100644 index 0000000..bce473a --- /dev/null +++ b/memq/src/test/java/com/pinterest/memq/core/rpc/queue/TestQueuedPacketSwitchingHandler.java @@ -0,0 +1,907 @@ +/** + * Copyright 2022 Pinterest, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.pinterest.memq.core.rpc.queue; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import com.pinterest.memq.commons.protocol.RequestPacket; +import com.pinterest.memq.commons.protocol.WriteRequestPacket; +import com.pinterest.memq.core.config.QueueingConfig; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.PooledByteBufAllocator; + +/** + * Tests for QueuedPacketSwitchingHandler and related queueing components. + * + * These tests focus on the core queueing functionality: + * - Enqueueing and dequeueing + * - Fairness across topic queues + * - Queue full scenarios + */ +public class TestQueuedPacketSwitchingHandler { + + private QueueingConfig config; + private DeficitRoundRobinStrategy strategy; + + @Before + public void setUp() { + config = new QueueingConfig(); + config.setEnabled(true); + config.setDequeueThreadPoolSize(4); + config.setMaxQueueBytesPerTopic(1024 * 1024); // 1MB per topic + config.getStrategyConfig().put(DeficitRoundRobinStrategy.CONFIG_KEY_QUANTUM, 1024); // 1KB quantum + + strategy = new DeficitRoundRobinStrategy(); + strategy.init(config); + } + + @After + public void tearDown() { + if (strategy != null) { + strategy.shutdown(); + } + } + + /** + * Helper to create a QueuedRequest with a given topic name and data size. + */ + private QueuedRequest createRequest(String topicName, int dataSize) { + WriteRequestPacket writePacket = new WriteRequestPacket(); + writePacket.setTopicName(topicName); + writePacket.setData(new byte[dataSize]); + + RequestPacket requestPacket = new RequestPacket(); + // ChannelHandlerContext is null for these unit tests - we only test queueing logic + return new QueuedRequest(null, requestPacket, writePacket); + } + + // =========================================== + // Basic Enqueue/Dequeue Tests + // =========================================== + + @Test + public void testBasicEnqueueDequeue() { + String topic = "test-topic"; + QueuedRequest request = createRequest(topic, 100); + + // Enqueue should succeed + assertTrue(strategy.enqueue(request)); + assertEquals(1, strategy.getPendingCount()); + assertEquals(1, strategy.getPendingCount(topic)); + + // Dequeue should return the request + Set topics = new HashSet<>(); + topics.add(topic); + QueuedRequest dequeued = strategy.dequeue(topics); + + assertNotNull(dequeued); + assertEquals(topic, dequeued.getTopicName()); + assertEquals(0, strategy.getPendingCount()); + } + + @Test + public void testEnqueueMultipleTopics() { + String topic1 = "topic-1"; + String topic2 = "topic-2"; + String topic3 = "topic-3"; + + // Enqueue requests for different topics + assertTrue(strategy.enqueue(createRequest(topic1, 100))); + assertTrue(strategy.enqueue(createRequest(topic2, 100))); + assertTrue(strategy.enqueue(createRequest(topic3, 100))); + + assertEquals(3, strategy.getPendingCount()); + assertEquals(1, strategy.getPendingCount(topic1)); + assertEquals(1, strategy.getPendingCount(topic2)); + assertEquals(1, strategy.getPendingCount(topic3)); + + // Active topics should contain all three + Set activeTopics = strategy.getActiveTopics(); + assertEquals(3, activeTopics.size()); + assertTrue(activeTopics.contains(topic1)); + assertTrue(activeTopics.contains(topic2)); + assertTrue(activeTopics.contains(topic3)); + } + + @Test + public void testDequeueEmptyQueue() { + Set topics = new HashSet<>(); + topics.add("nonexistent-topic"); + + QueuedRequest dequeued = strategy.dequeue(topics); + assertNull(dequeued); + } + + @Test + public void testDequeueWithNoAssignedTopics() { + String topic = "test-topic"; + assertTrue(strategy.enqueue(createRequest(topic, 100))); + + // Dequeue with empty assigned topics should return null + Set emptyTopics = new HashSet<>(); + QueuedRequest dequeued = strategy.dequeue(emptyTopics); + assertNull(dequeued); + + // Request should still be in queue + assertEquals(1, strategy.getPendingCount()); + } + + @Test + public void testDequeueOnlyAssignedTopics() { + String topic1 = "topic-1"; + String topic2 = "topic-2"; + + assertTrue(strategy.enqueue(createRequest(topic1, 100))); + assertTrue(strategy.enqueue(createRequest(topic2, 100))); + + // Only assign topic1 + Set assignedTopics = new HashSet<>(); + assignedTopics.add(topic1); + + QueuedRequest dequeued = strategy.dequeue(assignedTopics); + assertNotNull(dequeued); + assertEquals(topic1, dequeued.getTopicName()); + + // topic2 should still be pending + assertEquals(1, strategy.getPendingCount()); + assertEquals(1, strategy.getPendingCount(topic2)); + } + + // =========================================== + // Queue Full Scenario Tests + // =========================================== + + @Test + public void testQueueFull() { + // Set a very small queue size in bytes (300 bytes = 3 requests of 100 bytes) + QueueingConfig smallConfig = new QueueingConfig(); + smallConfig.setMaxQueueBytesPerTopic(300); + smallConfig.getStrategyConfig().put(DeficitRoundRobinStrategy.CONFIG_KEY_QUANTUM, 1024); + + DeficitRoundRobinStrategy smallStrategy = new DeficitRoundRobinStrategy(); + smallStrategy.init(smallConfig); + + try { + String topic = "test-topic"; + + // Fill the queue (3 x 100 bytes = 300 bytes) + assertTrue(smallStrategy.enqueue(createRequest(topic, 100))); + assertTrue(smallStrategy.enqueue(createRequest(topic, 100))); + assertTrue(smallStrategy.enqueue(createRequest(topic, 100))); + + // Fourth request should be rejected (would exceed 300 bytes) + assertFalse(smallStrategy.enqueue(createRequest(topic, 100))); + + assertEquals(3, smallStrategy.getPendingCount()); + assertEquals(300, smallStrategy.getPendingBytes(topic)); + } finally { + smallStrategy.shutdown(); + } + } + + @Test + public void testQueueFullPerTopic() { + // Set a very small queue size in bytes (200 bytes per topic = 2 requests of 100 bytes) + QueueingConfig smallConfig = new QueueingConfig(); + smallConfig.setMaxQueueBytesPerTopic(200); + smallConfig.getStrategyConfig().put(DeficitRoundRobinStrategy.CONFIG_KEY_QUANTUM, 1024); + + DeficitRoundRobinStrategy smallStrategy = new DeficitRoundRobinStrategy(); + smallStrategy.init(smallConfig); + + try { + String topic1 = "topic-1"; + String topic2 = "topic-2"; + + // Fill topic1 queue (2 x 100 bytes = 200 bytes) + assertTrue(smallStrategy.enqueue(createRequest(topic1, 100))); + assertTrue(smallStrategy.enqueue(createRequest(topic1, 100))); + assertFalse(smallStrategy.enqueue(createRequest(topic1, 100))); // rejected + + // topic2 should still accept (independent byte limit) + assertTrue(smallStrategy.enqueue(createRequest(topic2, 100))); + assertTrue(smallStrategy.enqueue(createRequest(topic2, 100))); + assertFalse(smallStrategy.enqueue(createRequest(topic2, 100))); // rejected + + assertEquals(4, smallStrategy.getPendingCount()); + assertEquals(2, smallStrategy.getPendingCount(topic1)); + assertEquals(2, smallStrategy.getPendingCount(topic2)); + assertEquals(200, smallStrategy.getPendingBytes(topic1)); + assertEquals(200, smallStrategy.getPendingBytes(topic2)); + } finally { + smallStrategy.shutdown(); + } + } + + @Test + public void testQueueFullBytesBased() { + // Test that queue limit is enforced by bytes, not count + QueueingConfig byteConfig = new QueueingConfig(); + byteConfig.setMaxQueueBytesPerTopic(500); + byteConfig.getStrategyConfig().put(DeficitRoundRobinStrategy.CONFIG_KEY_QUANTUM, 1024); + + DeficitRoundRobinStrategy byteStrategy = new DeficitRoundRobinStrategy(); + byteStrategy.init(byteConfig); + + try { + String topic = "test-topic"; + + // Enqueue a large request (400 bytes) + assertTrue(byteStrategy.enqueue(createRequest(topic, 400))); + assertEquals(400, byteStrategy.getPendingBytes(topic)); + + // Enqueue a small request (50 bytes) - should succeed + assertTrue(byteStrategy.enqueue(createRequest(topic, 50))); + assertEquals(450, byteStrategy.getPendingBytes(topic)); + + // Another 50 bytes should succeed (total 500) + assertTrue(byteStrategy.enqueue(createRequest(topic, 50))); + assertEquals(500, byteStrategy.getPendingBytes(topic)); + + // Any additional bytes should be rejected + assertFalse(byteStrategy.enqueue(createRequest(topic, 1))); + assertFalse(byteStrategy.enqueue(createRequest(topic, 100))); + + assertEquals(3, byteStrategy.getPendingCount()); + assertEquals(500, byteStrategy.getPendingBytes(topic)); + } finally { + byteStrategy.shutdown(); + } + } + + // =========================================== + // Fairness Tests + // =========================================== + + @Test + public void testRoundRobinFairness() { + String topic1 = "topic-a"; + String topic2 = "topic-b"; + String topic3 = "topic-c"; + + // Enqueue multiple requests per topic (same size) + for (int i = 0; i < 10; i++) { + strategy.enqueue(createRequest(topic1, 100)); + strategy.enqueue(createRequest(topic2, 100)); + strategy.enqueue(createRequest(topic3, 100)); + } + + assertEquals(30, strategy.getPendingCount()); + + // Dequeue all and count per topic + Set allTopics = new HashSet<>(); + allTopics.add(topic1); + allTopics.add(topic2); + allTopics.add(topic3); + + Map dequeueCounts = new HashMap<>(); + dequeueCounts.put(topic1, new AtomicInteger(0)); + dequeueCounts.put(topic2, new AtomicInteger(0)); + dequeueCounts.put(topic3, new AtomicInteger(0)); + + // Dequeue in batches and check fairness + List dequeueOrder = new ArrayList<>(); + for (int i = 0; i < 30; i++) { + QueuedRequest req = strategy.dequeue(allTopics); + assertNotNull("Should have request at iteration " + i, req); + dequeueCounts.get(req.getTopicName()).incrementAndGet(); + dequeueOrder.add(req.getTopicName()); + } + + // All topics should have been dequeued completely + assertEquals(10, dequeueCounts.get(topic1).get()); + assertEquals(10, dequeueCounts.get(topic2).get()); + assertEquals(10, dequeueCounts.get(topic3).get()); + + // Verify round-robin behavior - check that we alternate between topics + // in the first few dequeues (before any topic runs out of quantum) + // With DRR and same-sized requests, we should see interleaving + verifyInterleaving(dequeueOrder.subList(0, Math.min(15, dequeueOrder.size()))); + } + + /** + * Verify that the dequeue order shows interleaving between topics. + */ + private void verifyInterleaving(List order) { + // Count consecutive same-topic dequeues + int maxConsecutive = 0; + int currentConsecutive = 1; + + for (int i = 1; i < order.size(); i++) { + if (order.get(i).equals(order.get(i - 1))) { + currentConsecutive++; + maxConsecutive = Math.max(maxConsecutive, currentConsecutive); + } else { + currentConsecutive = 1; + } + } + + // With fair queueing and equal-sized requests, we shouldn't see + // long runs of the same topic (indicates starvation) + assertTrue("Max consecutive from same topic should be reasonable: " + maxConsecutive, + maxConsecutive <= 5); + } + + @Test + public void testDRRFairnessWithDifferentSizes() { + String smallTopic = "small"; + String largeTopic = "large"; + + // Set quantum to 500 bytes + QueueingConfig drrConfig = new QueueingConfig(); + drrConfig.setMaxQueueBytesPerTopic(1024 * 1024); // 1MB + drrConfig.getStrategyConfig().put(DeficitRoundRobinStrategy.CONFIG_KEY_QUANTUM, 500); + + DeficitRoundRobinStrategy drrStrategy = new DeficitRoundRobinStrategy(); + drrStrategy.init(drrConfig); + + try { + // Enqueue small requests (100 bytes each) and large requests (400 bytes each) + for (int i = 0; i < 20; i++) { + drrStrategy.enqueue(createRequest(smallTopic, 100)); + drrStrategy.enqueue(createRequest(largeTopic, 400)); + } + + assertEquals(40, drrStrategy.getPendingCount()); + + Set allTopics = new HashSet<>(); + allTopics.add(smallTopic); + allTopics.add(largeTopic); + + Map bytesDequeued = new HashMap<>(); + bytesDequeued.put(smallTopic, 0L); + bytesDequeued.put(largeTopic, 0L); + + // Dequeue all requests + int dequeueCount = 0; + while (drrStrategy.getPendingCount() > 0 && dequeueCount < 50) { + QueuedRequest req = drrStrategy.dequeue(allTopics); + if (req != null) { + bytesDequeued.compute(req.getTopicName(), + (k, v) -> v + req.getSize()); + dequeueCount++; + } + } + + // Both topics should have been processed with roughly fair byte allocation + // Small topic: 20 * 100 = 2000 bytes + // Large topic: 20 * 400 = 8000 bytes + assertEquals(2000L, (long) bytesDequeued.get(smallTopic)); + assertEquals(8000L, (long) bytesDequeued.get(largeTopic)); + assertEquals(0, drrStrategy.getPendingCount()); + } finally { + drrStrategy.shutdown(); + } + } + + @Test + public void testFairnessUnderLoad() { + // Test fairness when topics have different numbers of pending requests + String hotTopic = "hot-topic"; + String coldTopic = "cold-topic"; + + // Hot topic has many requests + for (int i = 0; i < 50; i++) { + strategy.enqueue(createRequest(hotTopic, 100)); + } + + // Cold topic has few requests + for (int i = 0; i < 5; i++) { + strategy.enqueue(createRequest(coldTopic, 100)); + } + + Set allTopics = new HashSet<>(); + allTopics.add(hotTopic); + allTopics.add(coldTopic); + + // Dequeue first 10 requests and check that cold topic gets fair share + int coldCount = 0; + + for (int i = 0; i < 10; i++) { + QueuedRequest req = strategy.dequeue(allTopics); + if (req != null && coldTopic.equals(req.getTopicName())) { + coldCount++; + } + } + + // Cold topic should get processed, not starved + assertTrue("Cold topic should get some processing: " + coldCount, coldCount >= 1); + + // Continue until cold topic is empty + while (strategy.getPendingCount(coldTopic) > 0) { + QueuedRequest req = strategy.dequeue(allTopics); + if (req != null && coldTopic.equals(req.getTopicName())) { + coldCount++; + } + } + + assertEquals(5, coldCount); + } + + // =========================================== + // Thread Assignment Tests + // =========================================== + + @Test + public void testTopicThreadAssignment() { + // Test that topics are consistently assigned to threads + List topics = new ArrayList<>(); + for (int i = 0; i < 20; i++) { + topics.add("topic-" + i); + } + + int numThreads = 4; + Map> threadAssignments = new HashMap<>(); + + for (int i = 0; i < numThreads; i++) { + threadAssignments.put(i, new HashSet<>()); + } + + // Simulate thread assignment logic from QueuedPacketSwitchingHandler + for (String topic : topics) { + int assignedThread = Math.abs(topic.hashCode()) % numThreads; + threadAssignments.get(assignedThread).add(topic); + } + + // Verify all topics are assigned + int totalAssigned = threadAssignments.values().stream() + .mapToInt(Set::size).sum(); + assertEquals(20, totalAssigned); + + // Verify each topic is assigned to exactly one thread + Set allAssigned = new HashSet<>(); + for (Set assigned : threadAssignments.values()) { + for (String topic : assigned) { + assertFalse("Topic should only be assigned once: " + topic, + allAssigned.contains(topic)); + allAssigned.add(topic); + } + } + + // Verify distribution is somewhat balanced (not all in one thread) + for (int i = 0; i < numThreads; i++) { + int assigned = threadAssignments.get(i).size(); + assertTrue("Thread " + i + " should have some topics, but has " + assigned, + assigned >= 1); + } + } + + @Test + public void testConsistentThreadAssignment() { + // Verify that the same topic is always assigned to the same thread + String topic = "consistent-topic"; + int numThreads = 4; + + int expectedThread = Math.abs(topic.hashCode()) % numThreads; + + // Check multiple times + for (int i = 0; i < 100; i++) { + int assignedThread = Math.abs(topic.hashCode()) % numThreads; + assertEquals(expectedThread, assignedThread); + } + } + + // =========================================== + // Concurrent Access Tests + // =========================================== + + @Test + public void testConcurrentEnqueue() throws InterruptedException { + int numThreads = 10; + int requestsPerThread = 100; + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch doneLatch = new CountDownLatch(numThreads); + AtomicInteger successCount = new AtomicInteger(0); + + // Create threads that enqueue concurrently + for (int t = 0; t < numThreads; t++) { + final int threadId = t; + new Thread(() -> { + try { + startLatch.await(); + for (int i = 0; i < requestsPerThread; i++) { + String topic = "topic-" + (threadId % 3); // 3 topics + if (strategy.enqueue(createRequest(topic, 100))) { + successCount.incrementAndGet(); + } + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } finally { + doneLatch.countDown(); + } + }).start(); + } + + // Start all threads + startLatch.countDown(); + + // Wait for completion + assertTrue(doneLatch.await(10, TimeUnit.SECONDS)); + + // All enqueues should succeed (queue is large enough) + assertEquals(numThreads * requestsPerThread, successCount.get()); + assertEquals(numThreads * requestsPerThread, strategy.getPendingCount()); + } + + @Test + public void testConcurrentEnqueueDequeue() throws InterruptedException { + int numProducers = 5; + int numConsumers = 3; + int requestsPerProducer = 100; + + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch producersDone = new CountDownLatch(numProducers); + AtomicInteger produced = new AtomicInteger(0); + AtomicInteger consumed = new AtomicInteger(0); + AtomicInteger running = new AtomicInteger(1); + + Set allTopics = new HashSet<>(); + for (int i = 0; i < 3; i++) { + allTopics.add("topic-" + i); + } + + // Start consumers + for (int c = 0; c < numConsumers; c++) { + new Thread(() -> { + try { + startLatch.await(); + while (running.get() == 1 || strategy.getPendingCount() > 0) { + QueuedRequest req = strategy.dequeue(allTopics); + if (req != null) { + consumed.incrementAndGet(); + } else { + Thread.sleep(1); + } + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }).start(); + } + + // Start producers + for (int p = 0; p < numProducers; p++) { + final int producerId = p; + new Thread(() -> { + try { + startLatch.await(); + for (int i = 0; i < requestsPerProducer; i++) { + String topic = "topic-" + (producerId % 3); + if (strategy.enqueue(createRequest(topic, 100))) { + produced.incrementAndGet(); + } + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } finally { + producersDone.countDown(); + } + }).start(); + } + + // Start all threads + startLatch.countDown(); + + // Wait for producers to finish + assertTrue(producersDone.await(10, TimeUnit.SECONDS)); + + // Give consumers time to drain + Thread.sleep(500); + running.set(0); + + // Wait a bit more for final consumption + Thread.sleep(200); + + // All produced should be consumed + assertEquals(produced.get(), consumed.get()); + assertEquals(0, strategy.getPendingCount()); + } + + // =========================================== + // QueueingConfig Tests + // =========================================== + + @Test + public void testQueueingConfigDefaults() { + QueueingConfig defaultConfig = new QueueingConfig(); + + assertFalse(defaultConfig.isEnabled()); + assertEquals("com.pinterest.memq.core.rpc.queue.DeficitRoundRobinStrategy", + defaultConfig.getStrategyClass()); + assertEquals(4, defaultConfig.getDequeueThreadPoolSize()); + assertEquals(100 * 1024 * 1024, defaultConfig.getMaxQueueBytesPerTopic()); // 100MB + assertNotNull(defaultConfig.getStrategyConfig()); + assertTrue(defaultConfig.getStrategyConfig().isEmpty()); + } + + @Test + public void testQueueingConfigSetters() { + QueueingConfig customConfig = new QueueingConfig(); + customConfig.setEnabled(true); + customConfig.setStrategyClass("com.example.CustomStrategy"); + customConfig.setDequeueThreadPoolSize(8); + customConfig.setMaxQueueBytesPerTopic(50 * 1024 * 1024); // 50MB + customConfig.getStrategyConfig().put(DeficitRoundRobinStrategy.CONFIG_KEY_QUANTUM, 32768); + + assertTrue(customConfig.isEnabled()); + assertEquals("com.example.CustomStrategy", customConfig.getStrategyClass()); + assertEquals(8, customConfig.getDequeueThreadPoolSize()); + assertEquals(50 * 1024 * 1024, customConfig.getMaxQueueBytesPerTopic()); + assertEquals(32768, customConfig.getStrategyConfig().get(DeficitRoundRobinStrategy.CONFIG_KEY_QUANTUM)); + } + + @Test + public void testDRRDefaultQuantum() { + // Test that DRR uses default quantum when not configured + QueueingConfig noQuantumConfig = new QueueingConfig(); + noQuantumConfig.setMaxQueueBytesPerTopic(1024 * 1024); + // Don't set quantum in strategyConfig + + DeficitRoundRobinStrategy defaultStrategy = new DeficitRoundRobinStrategy(); + defaultStrategy.init(noQuantumConfig); + + try { + assertEquals(DeficitRoundRobinStrategy.DEFAULT_QUANTUM, defaultStrategy.getQuantum()); + assertEquals(65536, defaultStrategy.getQuantum()); // 64KB + } finally { + defaultStrategy.shutdown(); + } + } + + @Test + public void testDRRCustomQuantum() { + // Test that DRR uses custom quantum when configured + QueueingConfig customQuantumConfig = new QueueingConfig(); + customQuantumConfig.setMaxQueueBytesPerTopic(1024 * 1024); + customQuantumConfig.getStrategyConfig().put(DeficitRoundRobinStrategy.CONFIG_KEY_QUANTUM, 8192); + + DeficitRoundRobinStrategy customStrategy = new DeficitRoundRobinStrategy(); + customStrategy.init(customQuantumConfig); + + try { + assertEquals(8192, customStrategy.getQuantum()); + } finally { + customStrategy.shutdown(); + } + } + + @Test + public void testBytesDecrementOnDequeue() { + // Verify that bytes are correctly decremented when requests are dequeued + String topic = "test-topic"; + + assertTrue(strategy.enqueue(createRequest(topic, 100))); + assertTrue(strategy.enqueue(createRequest(topic, 200))); + assertTrue(strategy.enqueue(createRequest(topic, 300))); + + assertEquals(600, strategy.getPendingBytes(topic)); + assertEquals(3, strategy.getPendingCount(topic)); + + Set topics = new HashSet<>(); + topics.add(topic); + + // Dequeue first request + QueuedRequest req1 = strategy.dequeue(topics); + assertNotNull(req1); + assertEquals(100, req1.getSize()); + assertEquals(500, strategy.getPendingBytes(topic)); + assertEquals(2, strategy.getPendingCount(topic)); + + // Dequeue second request + QueuedRequest req2 = strategy.dequeue(topics); + assertNotNull(req2); + assertEquals(200, req2.getSize()); + assertEquals(300, strategy.getPendingBytes(topic)); + assertEquals(1, strategy.getPendingCount(topic)); + + // Dequeue third request + QueuedRequest req3 = strategy.dequeue(topics); + assertNotNull(req3); + assertEquals(300, req3.getSize()); + assertEquals(0, strategy.getPendingBytes(topic)); + assertEquals(0, strategy.getPendingCount(topic)); + } + + // =========================================== + // ByteBuf Retain/Release Tests + // =========================================== + + @Test + public void testByteBufRetainReleaseOnEnqueueDequeue() { + // This test verifies the ByteBuf retain/release pattern that prevents + // "refCnt: 0" errors when requests are queued. + // + // The flow being tested: + // 1. Request with ByteBuf is created (refCnt = 1) + // 2. Before enqueueing, retain() is called (refCnt = 2) + // 3. Original holder releases (simulating MemqRequestDecoder finally block) (refCnt = 1) + // 4. After dequeue and processing, release() is called (refCnt = 0, buffer recycled) + + String topic = "test-topic"; + + // Create a request with a pooled buffer + ByteBuf pooledBuffer = PooledByteBufAllocator.DEFAULT.buffer(100); + pooledBuffer.writeBytes(new byte[100]); + assertEquals(1, pooledBuffer.refCnt()); + + WriteRequestPacket writePacket = new WriteRequestPacket(); + writePacket.setTopicName(topic); + writePacket.setData(pooledBuffer); + + // Simulate what QueuedPacketSwitchingHandler.handle() does: + // Retain before enqueueing + writePacket.getData().retain(); + assertEquals(2, pooledBuffer.refCnt()); + + QueuedRequest request = new QueuedRequest(null, new RequestPacket(), writePacket); + assertTrue(strategy.enqueue(request)); + + // Simulate MemqRequestDecoder finally block releasing the original buffer + pooledBuffer.release(); + assertEquals(1, pooledBuffer.refCnt()); // Still alive due to retain() + + // Dequeue the request + Set topics = new HashSet<>(); + topics.add(topic); + QueuedRequest dequeued = strategy.dequeue(topics); + assertNotNull(dequeued); + + // Buffer should still be usable + assertEquals(1, dequeued.getWritePacket().getData().refCnt()); + assertTrue(dequeued.getWritePacket().getData().isReadable()); + + // Simulate what processQueuedRequest finally block does: + // Release after processing + dequeued.getWritePacket().getData().release(); + assertEquals(0, pooledBuffer.refCnt()); // Now fully released + } + + @Test + public void testByteBufReleaseOnEnqueueFailure() { + // Test that ByteBuf is properly released when enqueue fails (queue full) + + QueueingConfig smallConfig = new QueueingConfig(); + smallConfig.setMaxQueueBytesPerTopic(100); // Very small + smallConfig.getStrategyConfig().put(DeficitRoundRobinStrategy.CONFIG_KEY_QUANTUM, 1024); + + DeficitRoundRobinStrategy smallStrategy = new DeficitRoundRobinStrategy(); + smallStrategy.init(smallConfig); + + try { + String topic = "test-topic"; + + // Fill the queue + ByteBuf buf1 = PooledByteBufAllocator.DEFAULT.buffer(100); + buf1.writeBytes(new byte[100]); + WriteRequestPacket writePacket1 = new WriteRequestPacket(); + writePacket1.setTopicName(topic); + writePacket1.setData(buf1); + + // Retain before enqueue (simulating handle()) + buf1.retain(); + assertEquals(2, buf1.refCnt()); + + assertTrue(smallStrategy.enqueue(new QueuedRequest(null, new RequestPacket(), writePacket1))); + + // Simulate decoder finally block + buf1.release(); + assertEquals(1, buf1.refCnt()); + + // Now try to enqueue another request that will fail + ByteBuf buf2 = PooledByteBufAllocator.DEFAULT.buffer(100); + buf2.writeBytes(new byte[100]); + WriteRequestPacket writePacket2 = new WriteRequestPacket(); + writePacket2.setTopicName(topic); + writePacket2.setData(buf2); + + // Retain before enqueue attempt + buf2.retain(); + assertEquals(2, buf2.refCnt()); + + // Enqueue should fail + boolean enqueued = smallStrategy.enqueue(new QueuedRequest(null, new RequestPacket(), writePacket2)); + assertFalse(enqueued); + + // On failure, QueuedPacketSwitchingHandler releases the buffer it retained + if (!enqueued) { + buf2.release(); // Simulating the release on failure in handle() + } + assertEquals(1, buf2.refCnt()); + + // Simulate decoder finally block + buf2.release(); + assertEquals(0, buf2.refCnt()); // Properly released + + // Clean up the first buffer + Set topics = new HashSet<>(); + topics.add(topic); + QueuedRequest dequeued = smallStrategy.dequeue(topics); + assertNotNull(dequeued); + dequeued.getWritePacket().getData().release(); + assertEquals(0, buf1.refCnt()); + + } finally { + smallStrategy.shutdown(); + } + } + + @Test + public void testByteBufRemainsReadableWhileQueued() { + // Test that the ByteBuf data remains readable while the request is queued + + String topic = "test-topic"; + byte[] testData = "Hello, World!".getBytes(); + + ByteBuf buffer = PooledByteBufAllocator.DEFAULT.buffer(testData.length); + buffer.writeBytes(testData); + + WriteRequestPacket writePacket = new WriteRequestPacket(); + writePacket.setTopicName(topic); + writePacket.setData(buffer); + + // Retain before enqueueing + buffer.retain(); + assertEquals(2, buffer.refCnt()); + + QueuedRequest request = new QueuedRequest(null, new RequestPacket(), writePacket); + assertTrue(strategy.enqueue(request)); + + // Simulate decoder releasing + buffer.release(); + assertEquals(1, buffer.refCnt()); + + // Dequeue + Set topics = new HashSet<>(); + topics.add(topic); + QueuedRequest dequeued = strategy.dequeue(topics); + assertNotNull(dequeued); + + // Verify data is still readable + ByteBuf dequeuedData = dequeued.getWritePacket().getData(); + assertEquals(testData.length, dequeuedData.readableBytes()); + + byte[] readData = new byte[testData.length]; + dequeuedData.readBytes(readData); + + for (int i = 0; i < testData.length; i++) { + assertEquals(testData[i], readData[i]); + } + + // Clean up + dequeuedData.release(); + assertEquals(0, buffer.refCnt()); + } +} From 7b83eb32e394d62bb206c4ca0ba2d7c9598326ca Mon Sep 17 00:00:00 2001 From: Jeff Xiang Date: Thu, 15 Jan 2026 14:18:59 -0500 Subject: [PATCH 6/7] Simple implementation of disabling broker assignments --- .../com/pinterest/memq/core/MemqManager.java | 60 +++- .../memq/core/clustering/MemqGovernor.java | 54 +++- .../memq/core/config/ClusteringConfig.java | 18 ++ .../memq/core/rpc/PacketSwitchingHandler.java | 88 +++++- .../pinterest/memq/core/TestMemqManager.java | 40 +++ .../core/clustering/TestMemqGovernor.java | 20 ++ .../core/rpc/TestPacketSwitchingHandler.java | 263 ++++++++++++++++++ 7 files changed, 519 insertions(+), 24 deletions(-) create mode 100644 memq/src/test/java/com/pinterest/memq/core/rpc/TestPacketSwitchingHandler.java diff --git a/memq/src/main/java/com/pinterest/memq/core/MemqManager.java b/memq/src/main/java/com/pinterest/memq/core/MemqManager.java index d3e7305..36ccb93 100644 --- a/memq/src/main/java/com/pinterest/memq/core/MemqManager.java +++ b/memq/src/main/java/com/pinterest/memq/core/MemqManager.java @@ -62,6 +62,7 @@ public class MemqManager implements Managed { private static final Gson gson = new Gson(); private Map processorMap = new ConcurrentHashMap<>(); private Map topicMap = new ConcurrentHashMap<>(); + private Map topicLastAccessMs = new ConcurrentHashMap<>(); private MemqConfig configuration; private ScheduledExecutorService timerService; private ScheduledExecutorService cleanupService; @@ -89,9 +90,11 @@ public void init() throws Exception { if (file.exists()) { byte[] bytes = Files.readAllBytes(file.toPath()); TopicAssignment[] topics = gson.fromJson(new String(bytes), TopicAssignment[].class); - topicMap = new ConcurrentHashMap<>(); - for (TopicAssignment topicConfig : topics) { - topicMap.put(topicConfig.getTopic(), topicConfig); + if (topics != null) { + topicMap = new ConcurrentHashMap<>(); + for (TopicAssignment topicConfig : topics) { + topicMap.put(topicConfig.getTopic(), topicConfig); + } } } if (configuration.getTopicConfig() != null) { @@ -99,8 +102,12 @@ public void init() throws Exception { topicMap.put(topicConfig.getTopic(), new TopicAssignment(topicConfig, -1)); } } - for (Entry entry : topicMap.entrySet()) { - createTopicProcessor(entry.getValue()); + boolean assignmentsEnabled = configuration.getClusteringConfig() == null + || configuration.getClusteringConfig().isEnableAssignments(); + if (assignmentsEnabled) { + for (Entry entry : topicMap.entrySet()) { + createTopicProcessor(entry.getValue()); + } } } @@ -148,6 +155,7 @@ public void createTopicProcessor(TopicAssignment topicConfig) throws BadRequestE processorMap.put(topicConfig.getTopic(), tp); topicMap.put(topicConfig.getTopic(), topicConfig); + topicLastAccessMs.put(topicConfig.getTopic(), System.currentTimeMillis()); logger.info("Configured and started TopicProcessor for:" + topicConfig.getTopic()); } @@ -183,6 +191,7 @@ public Future deleteTopicProcessor(String topic) { } processorMap.remove(topic); topicMap.remove(topic); + topicLastAccessMs.remove(topic); }); } @@ -211,6 +220,47 @@ public Map getRegistry() { return metricsRegistryMap; } + public TopicAssignment getTopicAssignment(String topic) { + return topicMap.get(topic); + } + + public TopicProcessor getOrCreateTopicProcessor(String topic) throws Exception { + TopicProcessor topicProcessor = processorMap.get(topic); + if (topicProcessor != null) { + return topicProcessor; + } + TopicAssignment assignment = topicMap.get(topic); + if (assignment == null) { + throw new NotFoundException("Topic not found:" + topic); + } + createTopicProcessor(assignment); + return processorMap.get(topic); + } + + public void touchTopic(String topic) { + topicLastAccessMs.put(topic, System.currentTimeMillis()); + } + + public void startIdleTopicCleanup(long maxIdleMs) { + if (maxIdleMs <= 0) { + return; + } + cleanupService.scheduleAtFixedRate(() -> { + long now = System.currentTimeMillis(); + for (Entry entry : topicLastAccessMs.entrySet()) { + String topic = entry.getKey(); + Long lastAccess = entry.getValue(); + if (lastAccess == null) { + continue; + } + if (now - lastAccess > maxIdleMs && processorMap.containsKey(topic)) { + logger.info("Deleting idle TopicProcessor for topic:" + topic); + deleteTopicProcessor(topic); + } + } + }, maxIdleMs, maxIdleMs, TimeUnit.MILLISECONDS); + } + @Override public void start() throws Exception { logger.info("Memq manager started"); diff --git a/memq/src/main/java/com/pinterest/memq/core/clustering/MemqGovernor.java b/memq/src/main/java/com/pinterest/memq/core/clustering/MemqGovernor.java index 36c2e13..11428bd 100644 --- a/memq/src/main/java/com/pinterest/memq/core/clustering/MemqGovernor.java +++ b/memq/src/main/java/com/pinterest/memq/core/clustering/MemqGovernor.java @@ -16,6 +16,7 @@ package com.pinterest.memq.core.clustering; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -125,18 +126,22 @@ public void takeLeadership(CuratorFramework client) throws Exception { } }); - if (clusteringConfig.isEnableBalancer()) { - Balancer balancer = new Balancer(config, this, client, leaderSelector); - Thread thBalancer = new Thread(balancer); - thBalancer.setName("BalancerThread"); - thBalancer.setDaemon(true); - thBalancer.start(); - } + if (clusteringConfig.isEnableAssignments()) { + if (clusteringConfig.isEnableBalancer()) { + Balancer balancer = new Balancer(config, this, client, leaderSelector); + Thread thBalancer = new Thread(balancer); + thBalancer.setName("BalancerThread"); + thBalancer.setDaemon(true); + thBalancer.start(); + } - Thread th = new Thread(new MetadataPoller(client, topicMetadataMap)); - th.setName("MetadataPollerThread"); - th.setDaemon(true); - th.start(); + Thread th = new Thread(new MetadataPoller(client, topicMetadataMap)); + th.setName("MetadataPollerThread"); + th.setDaemon(true); + th.start(); + } else { + mgr.startIdleTopicCleanup(clusteringConfig.getMaxIdleMs()); + } if (clusteringConfig.isEnableLeaderSelector()) { leaderSelector.autoRequeue(); @@ -171,7 +176,7 @@ private void initializeZNodesAndWatchers(CuratorFramework client) throws Excepti client.create().withMode(CreateMode.EPHEMERAL).forPath(brokerZnodePath, GSON.toJson(broker).getBytes()); - if (clusteringConfig.isEnableLocalAssigner()) { + if (clusteringConfig.isEnableAssignments() && clusteringConfig.isEnableLocalAssigner()) { Thread th = new Thread(new TopicAssignmentWatcher(mgr, brokerZnodePath, broker, client)); th.setDaemon(true); th.setName("TopicAssignmentWatcher"); @@ -183,6 +188,31 @@ public Map getTopicMetadataMap() { return topicMetadataMap; } + public TopicConfig getTopicConfig(String topic) throws Exception { + if (client == null) { + return null; + } + String path = ZNODE_TOPICS_BASE + topic; + if (client.checkExists().forPath(path) == null) { + return null; + } + String topicConfig = new String(client.getData().forPath(path)); + return GSON.fromJson(topicConfig, TopicConfig.class); + } + + public Set getAllBrokers() throws Exception { + Set brokers = new HashSet<>(); + if (client == null) { + return brokers; + } + for (String id : client.getChildren().forPath(ZNODE_BROKERS)) { + String brokerInfo = new String(client.getData().forPath(ZNODE_BROKERS_BASE + id)); + Broker broker = GSON.fromJson(brokerInfo, Broker.class); + brokers.add(broker); + } + return brokers; + } + public static List convertTopicAssignmentsSetToList(Set topicAssignments) { return topicAssignments.stream().map(TopicConfig::getTopic).collect(Collectors.toList()); } diff --git a/memq/src/main/java/com/pinterest/memq/core/config/ClusteringConfig.java b/memq/src/main/java/com/pinterest/memq/core/config/ClusteringConfig.java index 228c999..b18f5c3 100644 --- a/memq/src/main/java/com/pinterest/memq/core/config/ClusteringConfig.java +++ b/memq/src/main/java/com/pinterest/memq/core/config/ClusteringConfig.java @@ -24,6 +24,8 @@ public class ClusteringConfig { private boolean enableLocalAssigner = true; private boolean addBootstrapTopics = true; private boolean enableExpiration = true; + private boolean enableAssignments = true; + private long maxIdleMs = 5 * 60 * 1000; // 5 minutes public boolean isAddBootstrapTopics() { return addBootstrapTopics; @@ -80,4 +82,20 @@ public boolean isEnableExpiration() { public void setEnableExpiration(boolean enableExpiration) { this.enableExpiration = enableExpiration; } + + public boolean isEnableAssignments() { + return enableAssignments; + } + + public void setEnableAssignments(boolean enableAssignments) { + this.enableAssignments = enableAssignments; + } + + public long getMaxIdleMs() { + return maxIdleMs; + } + + public void setMaxIdleMs(long maxIdleMs) { + this.maxIdleMs = maxIdleMs; + } } diff --git a/memq/src/main/java/com/pinterest/memq/core/rpc/PacketSwitchingHandler.java b/memq/src/main/java/com/pinterest/memq/core/rpc/PacketSwitchingHandler.java index 4940cec..08e14d6 100644 --- a/memq/src/main/java/com/pinterest/memq/core/rpc/PacketSwitchingHandler.java +++ b/memq/src/main/java/com/pinterest/memq/core/rpc/PacketSwitchingHandler.java @@ -25,21 +25,25 @@ import com.codahale.metrics.Counter; import com.codahale.metrics.MetricRegistry; +import com.pinterest.memq.commons.protocol.Broker; import com.pinterest.memq.commons.protocol.ReadRequestPacket; import com.pinterest.memq.commons.protocol.RequestPacket; import com.pinterest.memq.commons.protocol.ResponseCodes; import com.pinterest.memq.commons.protocol.ResponsePacket; +import com.pinterest.memq.commons.protocol.TopicAssignment; import com.pinterest.memq.commons.protocol.TopicMetadata; import com.pinterest.memq.commons.protocol.TopicMetadataRequestPacket; import com.pinterest.memq.commons.protocol.TopicMetadataResponsePacket; +import com.pinterest.memq.commons.protocol.TopicConfig; import com.pinterest.memq.commons.protocol.WriteRequestPacket; -import com.pinterest.memq.commons.protocol.WriteResponsePacket; import com.pinterest.memq.commons.protocol.Broker.BrokerType; import com.pinterest.memq.core.MemqManager; import com.pinterest.memq.core.clustering.MemqGovernor; import com.pinterest.memq.core.processing.TopicProcessor; import com.pinterest.memq.core.security.Authorizer; +import java.util.Properties; + import io.netty.channel.ChannelHandlerContext; public class PacketSwitchingHandler { @@ -88,15 +92,19 @@ public void handle(ChannelHandlerContext ctx, case TOPIC_METADATA: TopicMetadataRequestPacket mdRequest = (TopicMetadataRequestPacket) requestPacket .getPayload(); - TopicMetadata md = governor.getTopicMetadataMap().get(mdRequest.getTopic()); + TopicMetadata md = null; + if (!assignmentsEnabled()) { + md = buildMetadataForAllBrokers(mdRequest.getTopic()); + } else { + md = governor.getTopicMetadataMap().get(mdRequest.getTopic()); + } if (md == null) { throw TOPIC_NOT_FOUND; - } else { - ResponsePacket msg = new ResponsePacket(requestPacket.getProtocolVersion(), - requestPacket.getClientRequestId(), requestPacket.getRequestType(), ResponseCodes.OK, - new TopicMetadataResponsePacket(md)); - ctx.writeAndFlush(msg); } + ResponsePacket msg = new ResponsePacket(requestPacket.getProtocolVersion(), + requestPacket.getClientRequestId(), requestPacket.getRequestType(), ResponseCodes.OK, + new TopicMetadataResponsePacket(md)); + ctx.writeAndFlush(msg); break; case READ: ReadRequestPacket readPacket = (ReadRequestPacket) requestPacket.getPayload(); @@ -144,6 +152,34 @@ protected void executeWriteRequest(ChannelHandlerContext ctx, throw SERVER_NOT_INITIALIZED; } TopicProcessor topicProcessor = mgr.getProcessorMap().get(writePacket.getTopicName()); + if (!assignmentsEnabled()) { + if (topicProcessor == null) { + try { + TopicAssignment assignment = mgr.getTopicAssignment(writePacket.getTopicName()); + if (assignment == null) { + TopicConfig topicConfig = governor.getTopicConfig(writePacket.getTopicName()); + if (topicConfig != null) { + assignment = new TopicAssignment(topicConfig, -1); + mgr.createTopicProcessor(assignment); + } + } + if (assignment == null) { + throw new NotFoundException("Topic not found:" + writePacket.getTopicName()); + } + topicProcessor = mgr.getOrCreateTopicProcessor(writePacket.getTopicName()); + } catch (NotFoundException e) { + logger.severe("Topic not found:" + writePacket.getTopicName()); + throw TOPIC_NOT_FOUND; + } catch (Exception e) { + throw new InternalServerErrorException(e); + } + } + writeRquestCounter.inc(); + mgr.touchTopic(writePacket.getTopicName()); + topicProcessor.registerChannel(ctx.channel()); + topicProcessor.write(requestPacket, writePacket, ctx); + return; + } if (topicProcessor != null) { writeRquestCounter.inc(); topicProcessor.registerChannel(ctx.channel()); @@ -168,4 +204,42 @@ protected void executeReadRequest(ChannelHandlerContext ctx, topicProcessor.read(requestPacket, readPacket, ctx); } } + + private boolean assignmentsEnabled() { + return mgr.getConfiguration().getClusteringConfig() == null + || mgr.getConfiguration().getClusteringConfig().isEnableAssignments(); + } + + private TopicMetadata buildMetadataForAllBrokers(String topic) throws Exception { + TopicMetadata md; + TopicAssignment assignment = mgr.getTopicAssignment(topic); + if (assignment != null) { + md = new TopicMetadata(topic, + assignment.getStorageHandlerName(), + assignment.getStorageHandlerConfig()); + } else { + TopicConfig topicConfig = governor.getTopicConfig(topic); + if (topicConfig != null) { + md = new TopicMetadata(topicConfig); + } else { + // When assignments are disabled, allow metadata for topics not yet known locally. + // Storage handler info will be empty; clients can still discover brokers. + md = new TopicMetadata(topic, null, new Properties()); + } + } + for (Broker broker : governor.getAllBrokers()) { + switch (broker.getBrokerType()) { + case READ: + md.getReadBrokers().add(broker); + break; + case WRITE: + md.getWriteBrokers().add(broker); + break; + default: + md.getReadBrokers().add(broker); + md.getWriteBrokers().add(broker); + } + } + return md; + } } diff --git a/memq/src/test/java/com/pinterest/memq/core/TestMemqManager.java b/memq/src/test/java/com/pinterest/memq/core/TestMemqManager.java index 5ef21a4..5932ba4 100644 --- a/memq/src/test/java/com/pinterest/memq/core/TestMemqManager.java +++ b/memq/src/test/java/com/pinterest/memq/core/TestMemqManager.java @@ -16,15 +16,18 @@ package com.pinterest.memq.core; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import java.io.File; import java.nio.file.Paths; import java.util.HashMap; +import java.util.concurrent.TimeUnit; import org.junit.Test; import com.pinterest.memq.commons.protocol.TopicAssignment; import com.pinterest.memq.commons.protocol.TopicConfig; +import com.pinterest.memq.core.config.ClusteringConfig; import com.pinterest.memq.core.config.MemqConfig; public class TestMemqManager { @@ -67,4 +70,41 @@ public void testTopicConfig() throws Exception { } + @Test + public void testAssignmentsDisabledLazyCreationAndIdleCleanup() throws Exception { + MemqConfig configuration = new MemqConfig(); + ClusteringConfig clusteringConfig = new ClusteringConfig(); + clusteringConfig.setEnableAssignments(false); + clusteringConfig.setMaxIdleMs(50); + configuration.setClusteringConfig(clusteringConfig); + File tmpFile = File.createTempFile("testmgrcache", "", Paths.get("/tmp").toFile()); + tmpFile.deleteOnExit(); + configuration.setTopicCacheFile(tmpFile.toString()); + tmpFile.delete(); + configuration.setTopicConfig(new TopicConfig[] { + new TopicConfig("test", "delayeddevnull") + }); + + MemqManager mgr = new MemqManager(null, configuration, new HashMap<>()); + mgr.init(); + + // No processors should be created on init when assignments are disabled. + assertEquals(0, mgr.getProcessorMap().size()); + assertEquals(1, mgr.getTopicAssignment().size()); + + // Create on first use. + assertTrue(mgr.getOrCreateTopicProcessor("test") != null); + assertEquals(1, mgr.getProcessorMap().size()); + + // Touch the topic and start idle cleanup with a short timeout. + mgr.touchTopic("test"); + mgr.startIdleTopicCleanup(25); + + long deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(2); + while (System.currentTimeMillis() < deadline && mgr.getProcessorMap().size() > 0) { + Thread.sleep(20); + } + assertEquals(0, mgr.getProcessorMap().size()); + } + } diff --git a/memq/src/test/java/com/pinterest/memq/core/clustering/TestMemqGovernor.java b/memq/src/test/java/com/pinterest/memq/core/clustering/TestMemqGovernor.java index 8b80d6e..b83b023 100644 --- a/memq/src/test/java/com/pinterest/memq/core/clustering/TestMemqGovernor.java +++ b/memq/src/test/java/com/pinterest/memq/core/clustering/TestMemqGovernor.java @@ -19,13 +19,20 @@ import java.io.FileNotFoundException; import java.io.FileReader; +import java.util.HashMap; +import java.util.Set; import org.junit.Test; import com.google.gson.Gson; import com.google.gson.JsonIOException; import com.google.gson.JsonSyntaxException; +import com.pinterest.memq.commons.protocol.Broker; import com.pinterest.memq.commons.protocol.TopicConfig; +import com.pinterest.memq.core.MemqManager; +import com.pinterest.memq.core.config.ClusteringConfig; +import com.pinterest.memq.core.config.LocalEnvironmentProvider; +import com.pinterest.memq.core.config.MemqConfig; public class TestMemqGovernor { @@ -41,4 +48,17 @@ public void testBackwardsCompatibility() throws JsonSyntaxException, JsonIOExcep assertEquals("customs3aync2", newConf.getStorageHandlerName()); } + @Test + public void testGetAllBrokersWithNoZkClient() throws Exception { + MemqConfig config = new MemqConfig(); + ClusteringConfig clusteringConfig = new ClusteringConfig(); + clusteringConfig.setEnableAssignments(false); + config.setClusteringConfig(clusteringConfig); + MemqManager mgr = new MemqManager(null, config, new HashMap<>()); + MemqGovernor governor = new MemqGovernor(mgr, config, new LocalEnvironmentProvider()); + + Set brokers = governor.getAllBrokers(); + assertEquals(0, brokers.size()); + } + } diff --git a/memq/src/test/java/com/pinterest/memq/core/rpc/TestPacketSwitchingHandler.java b/memq/src/test/java/com/pinterest/memq/core/rpc/TestPacketSwitchingHandler.java new file mode 100644 index 0000000..a0809ea --- /dev/null +++ b/memq/src/test/java/com/pinterest/memq/core/rpc/TestPacketSwitchingHandler.java @@ -0,0 +1,263 @@ +/** + * Copyright 2022 Pinterest, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.pinterest.memq.core.rpc; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.fail; + +import java.io.File; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Set; + +import org.junit.Test; + +import com.codahale.metrics.MetricRegistry; +import com.pinterest.memq.commons.protocol.Broker; +import com.pinterest.memq.commons.protocol.Broker.BrokerType; +import com.pinterest.memq.commons.protocol.RequestPacket; +import com.pinterest.memq.commons.protocol.RequestType; +import com.pinterest.memq.commons.protocol.ResponsePacket; +import com.pinterest.memq.commons.protocol.TopicMetadata; +import com.pinterest.memq.commons.protocol.TopicMetadataRequestPacket; +import com.pinterest.memq.commons.protocol.TopicMetadataResponsePacket; +import com.pinterest.memq.commons.protocol.TopicConfig; +import com.pinterest.memq.commons.protocol.WriteRequestPacket; +import com.pinterest.memq.core.MemqManager; +import com.pinterest.memq.core.clustering.MemqGovernor; +import com.pinterest.memq.core.config.ClusteringConfig; +import com.pinterest.memq.core.config.LocalEnvironmentProvider; +import com.pinterest.memq.core.config.MemqConfig; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.embedded.EmbeddedChannel; + +public class TestPacketSwitchingHandler { + + private static class TestGovernor extends MemqGovernor { + private final Set brokers; + + public TestGovernor(MemqManager mgr, + MemqConfig config, + LocalEnvironmentProvider provider, + Set brokers) { + super(mgr, config, provider); + this.brokers = brokers; + } + + @Override + public Set getAllBrokers() { + return brokers; + } + } + + private static class TestGovernorWithTopicConfig extends TestGovernor { + private final TopicConfig topicConfig; + + public TestGovernorWithTopicConfig(MemqManager mgr, + MemqConfig config, + LocalEnvironmentProvider provider, + Set brokers, + TopicConfig topicConfig) { + super(mgr, config, provider, brokers); + this.topicConfig = topicConfig; + } + + @Override + public TopicConfig getTopicConfig(String topic) { + if (topicConfig != null && topicConfig.getTopic().equals(topic)) { + return topicConfig; + } + return null; + } + } + + private static class PacketSwitchingAdapter extends ChannelInboundHandlerAdapter { + private final PacketSwitchingHandler handler; + + public PacketSwitchingAdapter(PacketSwitchingHandler handler) { + this.handler = handler; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + handler.handle(ctx, (RequestPacket) msg, null, null); + } + } + + @Test + public void testMetadataAllBrokersWhenAssignmentsDisabled() throws Exception { + MemqConfig config = new MemqConfig(); + ClusteringConfig clusteringConfig = new ClusteringConfig(); + clusteringConfig.setEnableAssignments(false); + config.setClusteringConfig(clusteringConfig); + File tmpFile = File.createTempFile("testmgrcache", "", Paths.get("/tmp").toFile()); + tmpFile.deleteOnExit(); + config.setTopicCacheFile(tmpFile.toString()); + config.setTopicConfig(new TopicConfig[] { + new TopicConfig("test", "delayeddevnull") + }); + + MemqManager mgr = new MemqManager(null, config, new HashMap<>()); + mgr.init(); + + Set brokers = new HashSet<>(Arrays.asList( + new Broker("1.1.1.1", (short) 9092, "test", "local", BrokerType.READ_WRITE, new HashSet<>()), + new Broker("1.1.1.2", (short) 9092, "test", "local", BrokerType.READ_WRITE, new HashSet<>()) + )); + + MemqGovernor governor = new TestGovernor(mgr, config, new LocalEnvironmentProvider(), brokers); + PacketSwitchingHandler handler = new PacketSwitchingHandler(mgr, governor, null, new MetricRegistry()); + + EmbeddedChannel channel = new EmbeddedChannel(new PacketSwitchingAdapter(handler)); + RequestPacket request = new RequestPacket((short) 0, 1L, RequestType.TOPIC_METADATA, + new TopicMetadataRequestPacket("test")); + channel.writeInbound(request); + + ResponsePacket response = channel.readOutbound(); + assertNotNull(response); + assertEquals(RequestType.TOPIC_METADATA, response.getRequestType()); + + TopicMetadataResponsePacket payload = (TopicMetadataResponsePacket) response.getPacket(); + TopicMetadata metadata = payload.getMetadata(); + assertNotNull(metadata); + assertEquals(2, metadata.getReadBrokers().size()); + assertEquals(2, metadata.getWriteBrokers().size()); + } + + @Test + public void testUnknownTopicMetadataAssignmentsEnabled() throws Exception { + MemqConfig config = new MemqConfig(); + ClusteringConfig clusteringConfig = new ClusteringConfig(); + clusteringConfig.setEnableAssignments(true); + config.setClusteringConfig(clusteringConfig); + config.setTopicConfig(new TopicConfig[] { + new TopicConfig("known", "delayeddevnull") + }); + + MemqManager mgr = new MemqManager(null, config, new HashMap<>()); + mgr.init(); + + MemqGovernor governor = new TestGovernor(mgr, config, new LocalEnvironmentProvider(), + Collections.emptySet()); + PacketSwitchingHandler handler = new PacketSwitchingHandler(mgr, governor, null, new MetricRegistry()); + + EmbeddedChannel channel = new EmbeddedChannel(new PacketSwitchingAdapter(handler)); + RequestPacket request = new RequestPacket((short) 0, 1L, RequestType.TOPIC_METADATA, + new TopicMetadataRequestPacket("unknown")); + try { + channel.writeInbound(request); + fail("Expected TOPIC_NOT_FOUND for unknown topic when assignments are enabled"); + } catch (Exception e) { + assertEquals(PacketSwitchingHandler.TOPIC_NOT_FOUND, e); + } + } + + @Test + public void testUnknownTopicMetadataAssignmentsDisabled() throws Exception { + MemqConfig config = new MemqConfig(); + ClusteringConfig clusteringConfig = new ClusteringConfig(); + clusteringConfig.setEnableAssignments(false); + config.setClusteringConfig(clusteringConfig); + + MemqManager mgr = new MemqManager(null, config, new HashMap<>()); + mgr.init(); + + Set brokers = new HashSet<>(Arrays.asList( + new Broker("1.1.1.1", (short) 9092, "test", "local", BrokerType.READ_WRITE, new HashSet<>()) + )); + + MemqGovernor governor = new TestGovernor(mgr, config, new LocalEnvironmentProvider(), brokers); + PacketSwitchingHandler handler = new PacketSwitchingHandler(mgr, governor, null, new MetricRegistry()); + + EmbeddedChannel channel = new EmbeddedChannel(new PacketSwitchingAdapter(handler)); + RequestPacket request = new RequestPacket((short) 0, 1L, RequestType.TOPIC_METADATA, + new TopicMetadataRequestPacket("unknown")); + channel.writeInbound(request); + + ResponsePacket response = channel.readOutbound(); + assertNotNull(response); + TopicMetadataResponsePacket payload = (TopicMetadataResponsePacket) response.getPacket(); + TopicMetadata metadata = payload.getMetadata(); + assertEquals("unknown", metadata.getTopicName()); + assertEquals(1, metadata.getWriteBrokers().size()); + assertEquals(1, metadata.getReadBrokers().size()); + } + + @Test + public void testUnknownTopicWriteAssignmentsDisabled() throws Exception { + MemqConfig config = new MemqConfig(); + ClusteringConfig clusteringConfig = new ClusteringConfig(); + clusteringConfig.setEnableAssignments(false); + config.setClusteringConfig(clusteringConfig); + + MemqManager mgr = new MemqManager(null, config, new HashMap<>()); + mgr.init(); + + MemqGovernor governor = new TestGovernor(mgr, config, new LocalEnvironmentProvider(), + Collections.emptySet()); + PacketSwitchingHandler handler = new PacketSwitchingHandler(mgr, governor, null, new MetricRegistry()); + + WriteRequestPacket writePacket = new WriteRequestPacket(); + writePacket.setTopicName("unknown"); + writePacket.setData(new byte[1]); + RequestPacket request = new RequestPacket((short) 0, 1L, RequestType.WRITE, writePacket); + try { + handler.handle(null, request, null, null); + fail("Expected TOPIC_NOT_FOUND for unknown topic write when assignments are disabled"); + } catch (Exception e) { + assertEquals(PacketSwitchingHandler.TOPIC_NOT_FOUND, e); + } + } + + @Test + public void testMetadataUsesZkTopicConfigWhenAssignmentsDisabled() throws Exception { + MemqConfig config = new MemqConfig(); + ClusteringConfig clusteringConfig = new ClusteringConfig(); + clusteringConfig.setEnableAssignments(false); + config.setClusteringConfig(clusteringConfig); + + MemqManager mgr = new MemqManager(null, config, new HashMap<>()); + mgr.init(); + + TopicConfig topicConfig = new TopicConfig("loggen_loadtest", "delayeddevnull"); + Set brokers = new HashSet<>(Arrays.asList( + new Broker("1.1.1.1", (short) 9092, "test", "local", BrokerType.READ_WRITE, new HashSet<>()) + )); + + MemqGovernor governor = new TestGovernorWithTopicConfig(mgr, config, + new LocalEnvironmentProvider(), brokers, topicConfig); + PacketSwitchingHandler handler = new PacketSwitchingHandler(mgr, governor, null, new MetricRegistry()); + + EmbeddedChannel channel = new EmbeddedChannel(new PacketSwitchingAdapter(handler)); + RequestPacket request = new RequestPacket((short) 0, 1L, RequestType.TOPIC_METADATA, + new TopicMetadataRequestPacket("loggen_loadtest")); + channel.writeInbound(request); + + ResponsePacket response = channel.readOutbound(); + assertNotNull(response); + TopicMetadataResponsePacket payload = (TopicMetadataResponsePacket) response.getPacket(); + TopicMetadata metadata = payload.getMetadata(); + assertEquals("loggen_loadtest", metadata.getTopicName()); + assertEquals("delayeddevnull", metadata.getStorageHandlerName()); + assertEquals(1, metadata.getWriteBrokers().size()); + } +} From ead01dd602fb012037fca9503f40a8afe850aefb Mon Sep 17 00:00:00 2001 From: Jeff Xiang Date: Thu, 15 Jan 2026 21:50:34 -0500 Subject: [PATCH 7/7] Attempt to fix queue.pending.bytes metric not showing up --- .../memq/core/rpc/queue/QueuedPacketSwitchingHandler.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueuedPacketSwitchingHandler.java b/memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueuedPacketSwitchingHandler.java index 5cbd6fb..0e5aaee 100644 --- a/memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueuedPacketSwitchingHandler.java +++ b/memq/src/main/java/com/pinterest/memq/core/rpc/queue/QueuedPacketSwitchingHandler.java @@ -251,13 +251,13 @@ private void handleProcessingError(QueuedRequest request, Exception e) { } private void registerPerTopicMetricsIfNeeded(String topicName) { - if (perTopicGaugeRegistered.putIfAbsent(topicName, Boolean.TRUE) != null) { - return; - } MetricRegistry topicRegistry = mgr.getRegistry().get(topicName); if (topicRegistry == null) { return; } + if (perTopicGaugeRegistered.putIfAbsent(topicName, Boolean.TRUE) != null) { + return; + } topicRegistry.gauge("queue.pending.bytes", () -> (Gauge) () -> strategy.getPendingBytes(topicName)); }