diff --git a/src/main/java/io/vlingo/wire/channel/RequestChannelConsumerProvider.java b/src/main/java/io/vlingo/wire/channel/RequestChannelConsumerProvider.java index a660ff7..535e214 100644 --- a/src/main/java/io/vlingo/wire/channel/RequestChannelConsumerProvider.java +++ b/src/main/java/io/vlingo/wire/channel/RequestChannelConsumerProvider.java @@ -7,6 +7,7 @@ package io.vlingo.wire.channel; +@FunctionalInterface public interface RequestChannelConsumerProvider { RequestChannelConsumer requestChannelConsumer(); } diff --git a/src/test/java/io/vlingo/wire/fdx/bidirectional/AbstractServerChannelActorTest.java b/src/test/java/io/vlingo/wire/fdx/bidirectional/AbstractServerChannelActorTest.java new file mode 100644 index 0000000..f9a013a --- /dev/null +++ b/src/test/java/io/vlingo/wire/fdx/bidirectional/AbstractServerChannelActorTest.java @@ -0,0 +1,187 @@ +// Copyright © 2012-2018 Vaughn Vernon. All rights reserved. +// +// This Source Code Form is subject to the terms of the +// Mozilla Public License, v. 2.0. If a copy of the MPL +// was not distributed with this file, You can obtain +// one at https://mozilla.org/MPL/2.0/. +package io.vlingo.wire.fdx.bidirectional; + +import io.vlingo.actors.World; +import io.vlingo.wire.channel.RequestChannelConsumerProvider; +import io.vlingo.wire.channel.ResponseChannelConsumer; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.Assert.assertEquals; + +public abstract class AbstractServerChannelActorTest { + private static AtomicInteger TEST_PORT = new AtomicInteger(49560); + + protected static final int POOL_SIZE = 100; + + private ClientRequestResponseChannel client; + private ServerRequestResponseChannel server; + private World world; + + protected int testPort = TEST_PORT.incrementAndGet(); + + protected abstract ClientRequestResponseChannel createClient(ResponseChannelConsumer consumer) + throws Exception; + + protected abstract ServerRequestResponseChannel createServer(World world, + RequestChannelConsumerProvider consumerProvider); + + protected abstract void request(ClientRequestResponseChannel client, final String request); + + @Before + public void setUp() { + world = World.startWithDefaults("test-request-response-channel"); + // the client is created lazily + } + + @Test + public void testBasicRequestResponse() + throws Exception { + final String request = "Hello, Request-Response"; + + TestRequestChannelConsumer serverConsumer = new TestRequestChannelConsumer(1, request.length()); + TestResponseChannelConsumer clientConsumer = new TestResponseChannelConsumer(1, request.length()); + client = createClient(clientConsumer); + server = createServer(world, () -> serverConsumer); + request(client, request); + + while (serverConsumer.remaining() > 0) { + ; + } + + while (clientConsumer.remaining() > 0) { + client.probeChannel(); + } + + assertEquals(1, serverConsumer.consumeCount()); + assertEquals(1, clientConsumer.consumeCount()); + assertEquals(clientConsumer.responses().get(0), serverConsumer.requests().get(0)); + } + + @Test + public void testGappyRequestResponse() throws Exception { + final String requestPart1 = "Request Part-1"; + final String requestPart2 = ":Request Part-2"; + final String requestPart3 = ":Request Part-3"; + + int expectedRequestLength = requestPart1.length() + requestPart2.length() + requestPart3.length(); + TestRequestChannelConsumer serverConsumer = new TestRequestChannelConsumer(1, expectedRequestLength); + TestResponseChannelConsumer clientConsumer = new TestResponseChannelConsumer(1, expectedRequestLength); + client = createClient(clientConsumer); + server = createServer(world, () -> serverConsumer); + + // simulate network latency for parts of single request + + request(client, requestPart1); + Thread.sleep(100); + request(client, requestPart2); + Thread.sleep(200); + request(client, requestPart3); + while (serverConsumer.remaining() > 0) { + ; + } + + while (clientConsumer.remaining() > 0) { + Thread.sleep(10); + client.probeChannel(); + } + + assertEquals(1, serverConsumer.consumeCount()); + assertEquals(1, clientConsumer.consumeCount()); + assertEquals(clientConsumer.responses().get(0), serverConsumer.requests().get(0)); + } + + + @Test + public void test10RequestResponse() + throws Exception { + final String request = "Hello, Request-Response"; + + int numMessages = 10; + int expectedRequestLength = request.length() + 1; // digits 0 - 9 + TestRequestChannelConsumer serverConsumer = new TestRequestChannelConsumer(numMessages, expectedRequestLength); + TestResponseChannelConsumer clientConsumer = new TestResponseChannelConsumer(numMessages, expectedRequestLength); + client = createClient(clientConsumer); + server = createServer(world, () -> serverConsumer); + + for (int idx = 0; idx < numMessages; ++idx) { + request(client, request + idx); + } + + while (clientConsumer.remaining() > 0) { + client.probeChannel(); + } + + assertEquals(numMessages, serverConsumer.consumeCount()); + assertEquals(numMessages, clientConsumer.consumeCount()); + + List requests = serverConsumer.requests(); + List responses = clientConsumer.responses(); + for (int idx = 0; idx < numMessages; ++idx) { + assertEquals(responses.get(idx), requests.get(idx)); + } + } + + @Test + public void testThatRequestResponsePoolLimitsNotExceeded() + throws Exception { + final int TOTAL = POOL_SIZE * 2; + + final String request = "Hello, Request-Response"; + + int expectedRequestLength = request.length() + 3; // digits 000 - 999 + TestRequestChannelConsumer serverConsumer = new TestRequestChannelConsumer(TOTAL, expectedRequestLength); + TestResponseChannelConsumer clientConsumer = new TestResponseChannelConsumer(TOTAL, expectedRequestLength); + client = createClient(clientConsumer); + server = createServer(world, () -> serverConsumer); + + for (int idx = 0; idx < TOTAL; ++idx) { + request(client, request + String.format("%03d", idx)); + } + + while (clientConsumer.remaining() > 0) { + client.probeChannel(); + } + + assertEquals(TOTAL, serverConsumer.consumeCount()); + assertEquals(TOTAL, clientConsumer.consumeCount()); + + List requests = serverConsumer.requests(); + List responses = clientConsumer.responses(); + for (int idx = 0; idx < TOTAL; ++idx) { + assertEquals(responses.get(idx), requests.get(idx)); + } + } + + @After + public void tearDown() { + try { + server.close(); + } catch (Exception e) { + // ignore + } + try { + client.close(); + } catch (Exception e) { + // ignore + } + + try { Thread.sleep(1000); } catch (Exception e) { } + + try { + world.terminate(); + } catch (Exception e) { + // ignore + } + } + +} \ No newline at end of file diff --git a/src/test/java/io/vlingo/wire/fdx/bidirectional/SocketRequestResponseChannelTest.java b/src/test/java/io/vlingo/wire/fdx/bidirectional/SocketRequestResponseChannelTest.java index e2b6db1..d38ead4 100644 --- a/src/test/java/io/vlingo/wire/fdx/bidirectional/SocketRequestResponseChannelTest.java +++ b/src/test/java/io/vlingo/wire/fdx/bidirectional/SocketRequestResponseChannelTest.java @@ -7,214 +7,43 @@ package io.vlingo.wire.fdx.bidirectional; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; - -import java.nio.ByteBuffer; -import java.util.concurrent.atomic.AtomicInteger; - -import org.junit.After; -import org.junit.Before; -import org.junit.Test; - import io.vlingo.actors.Logger; import io.vlingo.actors.World; -import io.vlingo.actors.testkit.TestUntil; +import io.vlingo.wire.channel.RequestChannelConsumerProvider; +import io.vlingo.wire.channel.ResponseChannelConsumer; import io.vlingo.wire.message.ByteBufferAllocator; import io.vlingo.wire.node.Address; import io.vlingo.wire.node.AddressType; import io.vlingo.wire.node.Host; -public class SocketRequestResponseChannelTest { - private static final int POOL_SIZE = 100; - private static AtomicInteger TEST_PORT = new AtomicInteger(37370); - - private ByteBuffer buffer; - private ClientRequestResponseChannel client; - private TestResponseChannelConsumer clientConsumer; - private TestRequestChannelConsumerProvider provider; - private ServerRequestResponseChannel server; - private TestRequestChannelConsumer serverConsumer; - private World world; - - @Test - public void testBasicRequestResponse() throws Exception { - final String request = "Hello, Request-Response"; - - serverConsumer.currentExpectedRequestLength = request.length(); - clientConsumer.currentExpectedResponseLength = serverConsumer.currentExpectedRequestLength; - request(request); - - serverConsumer.untilConsume = TestUntil.happenings(1); - clientConsumer.untilConsume = TestUntil.happenings(1); - - while (serverConsumer.untilConsume.remaining() > 0) { - ; - } - serverConsumer.untilConsume.completes(); - - while (clientConsumer.untilConsume.remaining() > 0) { - client.probeChannel(); - } - clientConsumer.untilConsume.completes(); - - assertFalse(serverConsumer.requests.isEmpty()); - assertEquals(1, serverConsumer.consumeCount); - assertEquals(serverConsumer.consumeCount, serverConsumer.requests.size()); - - assertFalse(clientConsumer.responses.isEmpty()); - assertEquals(1, clientConsumer.consumeCount); - assertEquals(clientConsumer.consumeCount, clientConsumer.responses.size()); - - assertEquals(clientConsumer.responses.get(0), serverConsumer.requests.get(0)); - } - - @Test - public void testGappyRequestResponse() throws Exception { - final String requestPart1 = "Request Part-1"; - final String requestPart2 = ":Request Part-2"; - final String requestPart3 = ":Request Part-3"; - - serverConsumer.currentExpectedRequestLength = requestPart1.length() + requestPart2.length() + requestPart3.length(); - clientConsumer.currentExpectedResponseLength = serverConsumer.currentExpectedRequestLength; - - // simulate network latency for parts of single request - - request(requestPart1); - Thread.sleep(100); - request(requestPart2); - Thread.sleep(200); - request(requestPart3); - serverConsumer.untilConsume = TestUntil.happenings(1); - while (serverConsumer.untilConsume.remaining() > 0) { - ; - } - serverConsumer.untilConsume.completes(); - - clientConsumer.untilConsume = TestUntil.happenings(1); - while (clientConsumer.untilConsume.remaining() > 0) { - Thread.sleep(10); - client.probeChannel(); - } - clientConsumer.untilConsume.completes(); - - assertFalse(serverConsumer.requests.isEmpty()); - assertEquals(1, serverConsumer.consumeCount); - assertEquals(serverConsumer.consumeCount, serverConsumer.requests.size()); - - assertFalse(clientConsumer.responses.isEmpty()); - assertEquals(1, clientConsumer.consumeCount); - assertEquals(clientConsumer.consumeCount, clientConsumer.responses.size()); - - assertEquals(clientConsumer.responses.get(0), serverConsumer.requests.get(0)); - } - - @Test - public void test10RequestResponse() throws Exception { - final String request = "Hello, Request-Response"; - - serverConsumer.currentExpectedRequestLength = request.length() + 1; // digits 0 - 9 - clientConsumer.currentExpectedResponseLength = serverConsumer.currentExpectedRequestLength; - - serverConsumer.untilConsume = TestUntil.happenings(10); - clientConsumer.untilConsume = TestUntil.happenings(10); - - for (int idx = 0; idx < 10; ++idx) { - request(request + idx); - } - - while (clientConsumer.untilConsume.remaining() > 0) { - client.probeChannel(); - } - - serverConsumer.untilConsume.completes(); - clientConsumer.untilConsume.completes(); - - assertFalse(serverConsumer.requests.isEmpty()); - assertEquals(10, serverConsumer.consumeCount); - assertEquals(serverConsumer.consumeCount, serverConsumer.requests.size()); - - assertFalse(clientConsumer.responses.isEmpty()); - assertEquals(10, clientConsumer.consumeCount); - assertEquals(clientConsumer.consumeCount, clientConsumer.responses.size()); - - for (int idx = 0; idx < 10; ++idx) { - assertEquals(clientConsumer.responses.get(idx), serverConsumer.requests.get(idx)); - } - } - - @Test - public void testThatRequestResponsePoolLimitsNotExceeded() throws Exception { - final int TOTAL = POOL_SIZE * 2; - - final String request = "Hello, Request-Response"; - - serverConsumer.currentExpectedRequestLength = request.length() + 3; // digits 000 - 999 - clientConsumer.currentExpectedResponseLength = serverConsumer.currentExpectedRequestLength; - - serverConsumer.untilConsume = TestUntil.happenings(TOTAL); - clientConsumer.untilConsume = TestUntil.happenings(TOTAL); - - for (int idx = 0; idx < TOTAL; ++idx) { - request(request + String.format("%03d", idx)); - } - - while (clientConsumer.untilConsume.remaining() > 0) { - client.probeChannel(); - } - serverConsumer.untilConsume.completes(); - clientConsumer.untilConsume.completes(); - - assertFalse(serverConsumer.requests.isEmpty()); - assertEquals(TOTAL, serverConsumer.consumeCount); - assertEquals(serverConsumer.consumeCount, serverConsumer.requests.size()); - - assertFalse(clientConsumer.responses.isEmpty()); - assertEquals(TOTAL, clientConsumer.consumeCount); - assertEquals(clientConsumer.consumeCount, clientConsumer.responses.size()); +import java.nio.ByteBuffer; - for (int idx = 0; idx < TOTAL; ++idx) { - assertEquals(clientConsumer.responses.get(idx), serverConsumer.requests.get(idx)); - } - } +public class SocketRequestResponseChannelTest extends AbstractServerChannelActorTest { - @Before - public void setUp() throws Exception { - world = World.startWithDefaults("test-request-response-channel"); + private ByteBuffer buffer = ByteBufferAllocator.allocate(1024); - buffer = ByteBufferAllocator.allocate(1024); + @Override + protected ClientRequestResponseChannel createClient(ResponseChannelConsumer consumer) + throws Exception { final Logger logger = Logger.basicLogger(); - provider = new TestRequestChannelConsumerProvider(); - serverConsumer = (TestRequestChannelConsumer) provider.consumer; - - final int testPort = TEST_PORT.incrementAndGet(); - - server = ServerRequestResponseChannel.start( - world.stage(), - provider, - testPort, - "test-server", - 1, - POOL_SIZE, - 10240, - 10L); - - clientConsumer = new TestResponseChannelConsumer(); - - client = new BasicClientRequestResponseChannel(Address.from(Host.of("localhost"), testPort, AddressType.NONE), clientConsumer, POOL_SIZE, 10240, logger); + return new BasicClientRequestResponseChannel(Address.from(Host.of("localhost"), testPort, AddressType.NONE), + consumer, POOL_SIZE, 10240, logger); } - @After - public void tearDown() { - server.close(); - client.close(); - - try { Thread.sleep(1000); } catch (Exception e) { } - - world.terminate(); + @Override + protected ServerRequestResponseChannel createServer(World world, RequestChannelConsumerProvider consumerProvider) { + return ServerRequestResponseChannel.start( + world.stage(), + consumerProvider, + testPort, + "test-server", + 1, + POOL_SIZE, + 10240, + 10L); } - private void request(final String request) { + protected void request(ClientRequestResponseChannel client, final String request) { buffer.clear(); buffer.put(request.getBytes()); buffer.flip(); diff --git a/src/test/java/io/vlingo/wire/fdx/bidirectional/TestRequestChannelConsumer.java b/src/test/java/io/vlingo/wire/fdx/bidirectional/TestRequestChannelConsumer.java index 48121d5..8fe3bc5 100644 --- a/src/test/java/io/vlingo/wire/fdx/bidirectional/TestRequestChannelConsumer.java +++ b/src/test/java/io/vlingo/wire/fdx/bidirectional/TestRequestChannelConsumer.java @@ -9,8 +9,9 @@ import java.util.ArrayList; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; -import io.vlingo.actors.testkit.TestUntil; +import io.vlingo.actors.testkit.AccessSafely; import io.vlingo.wire.channel.RequestChannelConsumer; import io.vlingo.wire.channel.RequestResponseContext; import io.vlingo.wire.message.BasicConsumerByteBuffer; @@ -18,18 +19,30 @@ import io.vlingo.wire.message.Converters; public class TestRequestChannelConsumer implements RequestChannelConsumer { - public int currentExpectedRequestLength; - public int consumeCount; - public List requests = new ArrayList<>(); - public TestUntil untilClosed; - public TestUntil untilConsume; - + private static final String ID_REQUESTS = "requests"; + + + private final AtomicBoolean untilClosed; + private final AccessSafely accessSafely; + private final int happenings; + private final int currentExpectedRequestLength; + private StringBuilder requestBuilder = new StringBuilder(); private String remaining = ""; + public TestRequestChannelConsumer(int happenings, int currentExpectedRequestLength) { + this.happenings = happenings; + this.currentExpectedRequestLength = currentExpectedRequestLength; + List requests = new ArrayList<>(); + this.untilClosed = new AtomicBoolean(false); + this.accessSafely = AccessSafely.afterCompleting(happenings) + .writingWith(ID_REQUESTS, msg -> requests.add((String) msg)) + .readingWith(ID_REQUESTS, () -> new ArrayList<>(requests)); + } + @Override public void closeWith(final RequestResponseContext requestResponseContext, final Object data) { - if (untilClosed != null) untilClosed.happened(); + untilClosed.set(true); } @Override @@ -44,27 +57,36 @@ public void consume(RequestResponseContext context, final ConsumerByteBuffer final String combinedRequests = requestBuilder.toString(); final int combinedLength = combinedRequests.length(); requestBuilder.setLength(0); // reuse - + int currentIndex = 0; boolean last = false; while (!last) { - final int endIndex = currentIndex+currentExpectedRequestLength; + final int endIndex = currentIndex + currentExpectedRequestLength; if (endIndex > combinedRequests.length()) { remaining = combinedRequests.substring(currentIndex); return; } final String request = combinedRequests.substring(currentIndex, endIndex); currentIndex += currentExpectedRequestLength; - requests.add(request); - ++consumeCount; - + accessSafely.writeUsing(ID_REQUESTS, request); + final ConsumerByteBuffer responseBuffer = new BasicConsumerByteBuffer(1, currentExpectedRequestLength); context.respondWith(responseBuffer.clear().put(request.getBytes()).flip()); // echo back - + last = currentIndex == combinedLength; - - if (untilConsume != null) untilConsume.happened(); } } } + + public int remaining() { + return happenings - accessSafely.totalWrites(); + } + + public int consumeCount() { + return accessSafely.totalWrites(); + } + + public List requests() { + return accessSafely.readFrom(ID_REQUESTS); + } } diff --git a/src/test/java/io/vlingo/wire/fdx/bidirectional/TestRequestChannelConsumerProvider.java b/src/test/java/io/vlingo/wire/fdx/bidirectional/TestRequestChannelConsumerProvider.java deleted file mode 100644 index 43a137a..0000000 --- a/src/test/java/io/vlingo/wire/fdx/bidirectional/TestRequestChannelConsumerProvider.java +++ /dev/null @@ -1,15 +0,0 @@ -package io.vlingo.wire.fdx.bidirectional; - -import io.vlingo.actors.testkit.TestUntil; -import io.vlingo.wire.channel.RequestChannelConsumer; -import io.vlingo.wire.channel.RequestChannelConsumerProvider; - -public class TestRequestChannelConsumerProvider implements RequestChannelConsumerProvider { - public TestUntil until; - public RequestChannelConsumer consumer = new TestRequestChannelConsumer(); - - @Override - public RequestChannelConsumer requestChannelConsumer() { - return consumer; - } -} diff --git a/src/test/java/io/vlingo/wire/fdx/bidirectional/TestResponseChannelConsumer.java b/src/test/java/io/vlingo/wire/fdx/bidirectional/TestResponseChannelConsumer.java index 480f240..a821531 100644 --- a/src/test/java/io/vlingo/wire/fdx/bidirectional/TestResponseChannelConsumer.java +++ b/src/test/java/io/vlingo/wire/fdx/bidirectional/TestResponseChannelConsumer.java @@ -10,19 +10,29 @@ import java.util.ArrayList; import java.util.List; -import io.vlingo.actors.testkit.TestUntil; +import io.vlingo.actors.testkit.AccessSafely; import io.vlingo.wire.channel.ResponseChannelConsumer; import io.vlingo.wire.message.ConsumerByteBuffer; import io.vlingo.wire.message.Converters; public class TestResponseChannelConsumer implements ResponseChannelConsumer { - public int currentExpectedResponseLength; - public int consumeCount; - public List responses = new ArrayList<>(); - public TestUntil untilConsume; - + private static final String ID_RESPONSES = "responses"; + + private final int happenings; + private final int currentExpectedResponseLength; + private final AccessSafely accessSafely; + private final StringBuilder responseBuilder = new StringBuilder(); - + + public TestResponseChannelConsumer(int happenings, int currentExpectedResponseLength) { + this.happenings = happenings; + this.currentExpectedResponseLength = currentExpectedResponseLength; + List responses = new ArrayList<>(); + this.accessSafely = AccessSafely.afterCompleting(happenings) + .writingWith(ID_RESPONSES, msg -> responses.add((String) msg)) + .readingWith(ID_RESPONSES, () -> new ArrayList<>(responses)); + } + @Override public void consume(final ConsumerByteBuffer buffer) { final String responsePart = Converters.bytesToText(buffer.array(), 0, buffer.limit()); @@ -40,14 +50,23 @@ public void consume(final ConsumerByteBuffer buffer) { while (!last) { final String request = combinedResponse.substring(currentIndex, currentIndex+currentExpectedResponseLength); currentIndex += currentExpectedResponseLength; - - responses.add(request); - ++consumeCount; - + + accessSafely.writeUsing(ID_RESPONSES, request); + last = currentIndex == combinedLength; - - untilConsume.happened(); } } } + + public int remaining() { + return happenings - accessSafely.totalWrites(); + } + + public int consumeCount() { + return accessSafely.totalWrites(); + } + + public List responses() { + return accessSafely.readFrom(ID_RESPONSES); + } } diff --git a/src/test/java/io/vlingo/wire/fdx/bidirectional/rsocket/RSocketServerChannelActorTest.java b/src/test/java/io/vlingo/wire/fdx/bidirectional/rsocket/RSocketServerChannelActorTest.java index 0ded1b1..44d38fb 100644 --- a/src/test/java/io/vlingo/wire/fdx/bidirectional/rsocket/RSocketServerChannelActorTest.java +++ b/src/test/java/io/vlingo/wire/fdx/bidirectional/rsocket/RSocketServerChannelActorTest.java @@ -9,225 +9,35 @@ import io.vlingo.actors.Definition; import io.vlingo.actors.Logger; import io.vlingo.actors.World; -import io.vlingo.actors.testkit.TestUntil; +import io.vlingo.wire.channel.RequestChannelConsumerProvider; +import io.vlingo.wire.channel.ResponseChannelConsumer; +import io.vlingo.wire.fdx.bidirectional.AbstractServerChannelActorTest; import io.vlingo.wire.fdx.bidirectional.ClientRequestResponseChannel; import io.vlingo.wire.fdx.bidirectional.ServerRequestResponseChannel; -import io.vlingo.wire.fdx.bidirectional.TestRequestChannelConsumer; -import io.vlingo.wire.fdx.bidirectional.TestRequestChannelConsumerProvider; -import io.vlingo.wire.fdx.bidirectional.TestResponseChannelConsumer; import io.vlingo.wire.node.Address; import io.vlingo.wire.node.AddressType; import io.vlingo.wire.node.Host; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; import java.nio.ByteBuffer; import java.time.Duration; import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; - -public class RSocketServerChannelActorTest { - private static final int POOL_SIZE = 100; - private static AtomicInteger TEST_PORT = new AtomicInteger(49560); - - private ClientRequestResponseChannel client; - private TestResponseChannelConsumer clientConsumer; - private TestRequestChannelConsumerProvider provider; - private ServerRequestResponseChannel server; - private TestRequestChannelConsumer serverConsumer; - private World world; - - @Test - public void testBasicRequestResponse() throws Exception { - final String request = "Hello, Request-Response"; - - serverConsumer.currentExpectedRequestLength = request.length(); - clientConsumer.currentExpectedResponseLength = serverConsumer.currentExpectedRequestLength; - request(request); - - serverConsumer.untilConsume = TestUntil.happenings(1); - clientConsumer.untilConsume = TestUntil.happenings(1); - - while (serverConsumer.untilConsume.remaining() > 0) { - ; - } - serverConsumer.untilConsume.completes(); - - while (clientConsumer.untilConsume.remaining() > 0) { - client.probeChannel(); - } - clientConsumer.untilConsume.completes(); - - assertFalse(serverConsumer.requests.isEmpty()); - assertEquals(1, serverConsumer.consumeCount); - assertEquals(serverConsumer.consumeCount, serverConsumer.requests.size()); - - assertFalse(clientConsumer.responses.isEmpty()); - assertEquals(1, clientConsumer.consumeCount); - assertEquals(clientConsumer.consumeCount, clientConsumer.responses.size()); - - assertEquals(clientConsumer.responses.get(0), serverConsumer.requests.get(0)); - } - - @Test - public void testGappyRequestResponse() throws Exception { - final String requestPart1 = "Request Part-1"; - final String requestPart2 = ":Request Part-2"; - final String requestPart3 = ":Request Part-3"; - - serverConsumer.currentExpectedRequestLength = requestPart1.length() + requestPart2.length() + requestPart3.length(); - clientConsumer.currentExpectedResponseLength = serverConsumer.currentExpectedRequestLength; - - // simulate network latency for parts of single request - - request(requestPart1); - Thread.sleep(100); - request(requestPart2); - Thread.sleep(200); - request(requestPart3); - serverConsumer.untilConsume = TestUntil.happenings(1); - while (serverConsumer.untilConsume.remaining() > 0) { - ; - } - serverConsumer.untilConsume.completes(); - - clientConsumer.untilConsume = TestUntil.happenings(1); - while (clientConsumer.untilConsume.remaining() > 0) { - Thread.sleep(10); - client.probeChannel(); - } - clientConsumer.untilConsume.completes(); - - assertFalse(serverConsumer.requests.isEmpty()); - assertEquals(1, serverConsumer.consumeCount); - assertEquals(serverConsumer.consumeCount, serverConsumer.requests.size()); - - assertFalse(clientConsumer.responses.isEmpty()); - assertEquals(1, clientConsumer.consumeCount); - assertEquals(clientConsumer.consumeCount, clientConsumer.responses.size()); - - assertEquals(clientConsumer.responses.get(0), serverConsumer.requests.get(0)); - } - - - @Test - public void test10RequestResponse() throws Exception { - final String request = "Hello, Request-Response"; - - serverConsumer.currentExpectedRequestLength = request.length() + 1; // digits 0 - 9 - clientConsumer.currentExpectedResponseLength = serverConsumer.currentExpectedRequestLength; - - serverConsumer.untilConsume = TestUntil.happenings(10); - clientConsumer.untilConsume = TestUntil.happenings(10); - - for (int idx = 0; idx < 10; ++idx) { - request(request + idx); - } - - while (clientConsumer.untilConsume.remaining() > 0) { - client.probeChannel(); - } - - serverConsumer.untilConsume.completes(); - clientConsumer.untilConsume.completes(); - - assertFalse(serverConsumer.requests.isEmpty()); - assertEquals(10, serverConsumer.consumeCount); - assertEquals(serverConsumer.consumeCount, serverConsumer.requests.size()); - - assertFalse(clientConsumer.responses.isEmpty()); - assertEquals(10, clientConsumer.consumeCount); - assertEquals(clientConsumer.consumeCount, clientConsumer.responses.size()); - - for (int idx = 0; idx < 10; ++idx) { - assertEquals(clientConsumer.responses.get(idx), serverConsumer.requests.get(idx)); - } - } - - @Test - public void testThatRequestResponsePoolLimitsNotExceeded() throws Exception { - final int TOTAL = POOL_SIZE * 2; - - final String request = "Hello, Request-Response"; - - serverConsumer.currentExpectedRequestLength = request.length() + 3; // digits 000 - 999 - clientConsumer.currentExpectedResponseLength = serverConsumer.currentExpectedRequestLength; - - serverConsumer.untilConsume = TestUntil.happenings(TOTAL); - clientConsumer.untilConsume = TestUntil.happenings(TOTAL); - - for (int idx = 0; idx < TOTAL; ++idx) { - request(request + String.format("%03d", idx)); - } - - while (clientConsumer.untilConsume.remaining() > 0) { - client.probeChannel(); - } - serverConsumer.untilConsume.completes(); - clientConsumer.untilConsume.completes(); - - assertFalse(serverConsumer.requests.isEmpty()); - assertEquals(TOTAL, serverConsumer.consumeCount); - assertEquals(serverConsumer.consumeCount, serverConsumer.requests.size()); - - assertFalse(clientConsumer.responses.isEmpty()); - assertEquals(TOTAL, clientConsumer.consumeCount); - assertEquals(clientConsumer.consumeCount, clientConsumer.responses.size()); - - for (int idx = 0; idx < TOTAL; ++idx) { - assertEquals(clientConsumer.responses.get(idx), serverConsumer.requests.get(idx)); - } - } - - @Before - public void setUp() throws Exception { - world = World.startWithDefaults("test-request-response-channel"); +public class RSocketServerChannelActorTest extends AbstractServerChannelActorTest { + protected ClientRequestResponseChannel createClient(ResponseChannelConsumer consumer) { final Logger logger = Logger.basicLogger(); - provider = new TestRequestChannelConsumerProvider(); - serverConsumer = (TestRequestChannelConsumer) provider.consumer; - - final int testPort = TEST_PORT.incrementAndGet(); - - final List params = Definition.parameters(provider, testPort, "test-server", POOL_SIZE, 10240); - - server = world.actorFor( - ServerRequestResponseChannel.class, - Definition.has(RSocketServerChannelActor.class, params)); - - - clientConsumer = new TestResponseChannelConsumer(); - - client = new RSocketClientChannel(Address.from(Host.of("127.0.0.1"), testPort, AddressType.NONE), clientConsumer, POOL_SIZE, 10240, logger, Duration.ofSeconds(1)); + return new RSocketClientChannel(Address.from(Host.of("127.0.0.1"), testPort, AddressType.NONE), + consumer, POOL_SIZE, 10240, logger, Duration.ofSeconds(1)); } - @After - public void tearDown() { - try { - server.close(); - } catch (Exception e) { - // ignore - } - try { - client.close(); - } catch (Exception e) { - // ignore - } - - try { Thread.sleep(1000); } catch (Exception e) { } - - try { - world.terminate(); - } catch (Exception e) { - // ignore - } + protected ServerRequestResponseChannel createServer(World world, RequestChannelConsumerProvider consumerProvider) { + final List params = Definition.parameters(consumerProvider, testPort, "test-server", POOL_SIZE, 10240); + return world.actorFor( + ServerRequestResponseChannel.class, + Definition.has(RSocketServerChannelActor.class, params)); } - private void request(final String request) { + protected void request(ClientRequestResponseChannel client, final String request) { client.requestWith(ByteBuffer.wrap(request.getBytes())); } diff --git a/src/test/java/io/vlingo/wire/fdx/inbound/InboundStreamTest.java b/src/test/java/io/vlingo/wire/fdx/inbound/InboundStreamTest.java index 8606d6a..fde9fca 100644 --- a/src/test/java/io/vlingo/wire/fdx/inbound/InboundStreamTest.java +++ b/src/test/java/io/vlingo/wire/fdx/inbound/InboundStreamTest.java @@ -16,7 +16,6 @@ import io.vlingo.actors.Definition; import io.vlingo.actors.testkit.TestActor; -import io.vlingo.actors.testkit.TestUntil; import io.vlingo.actors.testkit.TestWorld; import io.vlingo.wire.channel.MockChannelReader; import io.vlingo.wire.message.AbstractMessageTool; @@ -30,18 +29,16 @@ public class InboundStreamTest extends AbstractMessageTool { @Test public void testInbound() throws Exception { - interest.testResults.untilStops = TestUntil.happenings(1); while (reader.probeChannelCount.get() == 0) ; inboundStream.actor().stop(); int count = 0; - for (final String message : interest.testResults.messages) { + for (final String message : interest.testResults.getMessages()) { ++count; assertEquals(MockChannelReader.MessagePrefix + count, message); } - interest.testResults.untilStops.completes(); - - assertTrue(interest.testResults.messageCount.get() > 0); + + assertTrue(interest.testResults.messageCount() > 0); assertEquals(count, reader.probeChannelCount.get()); } @@ -49,7 +46,7 @@ public void testInbound() throws Exception { public void setUp() throws Exception { world = TestWorld.start("test-inbound-stream"); - interest = new MockInboundStreamInterest(); + interest = new MockInboundStreamInterest(1); reader = new MockChannelReader(); diff --git a/src/test/java/io/vlingo/wire/fdx/inbound/MockInboundStreamInterest.java b/src/test/java/io/vlingo/wire/fdx/inbound/MockInboundStreamInterest.java index 22e3a62..857960f 100644 --- a/src/test/java/io/vlingo/wire/fdx/inbound/MockInboundStreamInterest.java +++ b/src/test/java/io/vlingo/wire/fdx/inbound/MockInboundStreamInterest.java @@ -7,32 +7,58 @@ package io.vlingo.wire.fdx.inbound; +import java.util.ArrayList; import java.util.List; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.atomic.AtomicInteger; -import io.vlingo.actors.testkit.TestUntil; +import io.vlingo.actors.testkit.AccessSafely; import io.vlingo.wire.message.AbstractMessageTool; import io.vlingo.wire.message.RawMessage; import io.vlingo.wire.node.AddressType; public class MockInboundStreamInterest extends AbstractMessageTool implements InboundStreamInterest { - public TestResults testResults = new TestResults(); + public TestResults testResults; - public MockInboundStreamInterest() { } + public MockInboundStreamInterest(int happenings) { + testResults = new TestResults(happenings); + } @Override public void handleInboundStreamMessage(final AddressType addressType, final RawMessage message) { final String textMessage = message.asTextMessage(); - testResults.messages.add(textMessage); - testResults.messageCount.incrementAndGet(); - System.out.println("INTEREST: " + textMessage + " list-size: " + testResults.messages.size() + " count: " + testResults.messageCount.get() + " count-down: " + testResults.untilStops.remaining()); - testResults.untilStops.happened(); + testResults.addMessage(textMessage); + System.out.printf("INTEREST: %s; count: %s; count-down: %s%n", + textMessage, testResults.messageCount(), testResults.remaining()); } static class TestResults { - public final AtomicInteger messageCount = new AtomicInteger(0); - public final List messages = new CopyOnWriteArrayList<>(); - public TestUntil untilStops; + private static final String ID_MESSAGES = "messages"; + + private final int happenings; + private final AccessSafely accessSafely; + + TestResults(int happenings) { + this.happenings = happenings; + List messages = new ArrayList<>(); + this.accessSafely = AccessSafely + .afterCompleting(happenings) + .writingWith(ID_MESSAGES, msg -> messages.add((String) msg)) + .readingWith(ID_MESSAGES, () -> new ArrayList<>(messages)); + } + + void addMessage(String msg) { + accessSafely.writeUsing(ID_MESSAGES, msg); + } + + List getMessages() { + return accessSafely.readFrom(ID_MESSAGES); + } + + int messageCount() { + return accessSafely.totalWrites(); + } + + int remaining() { + return happenings - messageCount(); + } } }