diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index cdcf934fc..50fe6d1b1 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -78,7 +78,8 @@ jobs: strategy: fail-fast: false matrix: - WEAVIATE_VERSION: ["1.32.24", "1.33.11", "1.34.7", "1.35.2"] + WEAVIATE_VERSION: + ["1.32.24", "1.33.11", "1.34.7", "1.35.2", "1.36.0-rc.0"] steps: - uses: actions/checkout@v4 diff --git a/pom.xml b/pom.xml index 7a98a1a84..8414bc832 100644 --- a/pom.xml +++ b/pom.xml @@ -58,20 +58,20 @@ 3.20.0 4.13.2 2.0.3 - 3.27.6 + 3.27.7 1.0.4 5.21.0 2.0.17 1.5.18 5.14.0 2.21 - 11.31.1 + 11.33 5.15.0 - 4.33.4 - 4.33.4 - 1.78.0 - 1.78.0 - 1.78.0 + 4.33.5 + 4.33.5 + 1.79.0 + 1.79.0 + 1.79.0 6.0.53 diff --git a/src/it/java/io/weaviate/containers/Weaviate.java b/src/it/java/io/weaviate/containers/Weaviate.java index 0f463e96f..c12bf9f6e 100644 --- a/src/it/java/io/weaviate/containers/Weaviate.java +++ b/src/it/java/io/weaviate/containers/Weaviate.java @@ -26,7 +26,7 @@ public class Weaviate extends WeaviateContainer { public static final String DOCKER_IMAGE = "semitechnologies/weaviate"; - public static final String LATEST_VERSION = Version.V135.semver.toString(); + public static final String LATEST_VERSION = Version.latest().semver.toString(); public static final String VERSION; static { @@ -41,7 +41,8 @@ public enum Version { V132(1, 32, 24), V133(1, 33, 11), V134(1, 34, 7), - V135(1, 35, 2); + V135(1, 35, 2), + V136(1, 36, "0-rc.0"); public final SemanticVersion semver; @@ -49,9 +50,21 @@ private Version(int major, int minor, int patch) { this.semver = new SemanticVersion(major, minor, patch); } + private Version(int major, int minor, String patch) { + this.semver = new SemanticVersion(major, minor, patch); + } + public void orSkip() { ConcurrentTest.requireAtLeast(this); } + + public static Version latest() { + Version[] versions = Version.class.getEnumConstants(); + if (versions == null) { + throw new IllegalStateException("No versions are 
defined"); + } + return versions[versions.length - 1]; + } } /** diff --git a/src/it/java/io/weaviate/integration/BatchITest.java b/src/it/java/io/weaviate/integration/BatchITest.java new file mode 100644 index 000000000..0e57c914c --- /dev/null +++ b/src/it/java/io/weaviate/integration/BatchITest.java @@ -0,0 +1,48 @@ +package io.weaviate.integration; + +import java.io.IOException; +import java.util.Map; +import java.util.UUID; + +import org.assertj.core.api.Assertions; +import org.junit.BeforeClass; +import org.junit.Test; + +import io.weaviate.ConcurrentTest; +import io.weaviate.client6.v1.api.WeaviateClient; +import io.weaviate.client6.v1.api.collections.Property; +import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.containers.Container; +import io.weaviate.containers.Weaviate; + +public class BatchITest extends ConcurrentTest { + private static final WeaviateClient client = Container.WEAVIATE.getClient(); + + @BeforeClass + public static void __() { + Weaviate.Version.V136.orSkip(); + } + + @Test + public void test() throws IOException { + var nsThings = ns("Things"); + + var things = client.collections.create( + nsThings, + c -> c.properties(Property.text("letter"))); + + // Act + try (var batch = things.batch.start()) { + for (int i = 0; i < 10_000; i++) { + String uuid = UUID.randomUUID().toString(); + batch.add(WeaviateObject.of(builder -> builder + .uuid(uuid) + .properties(Map.of("letter", uuid.substring(0, 1))))); + } + } catch (InterruptedException e) { + } + + // Assert + Assertions.assertThat(things.size()).isEqualTo(10_000); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/Config.java b/src/main/java/io/weaviate/client6/v1/api/Config.java index b52b37066..33baabb63 100644 --- a/src/main/java/io/weaviate/client6/v1/api/Config.java +++ b/src/main/java/io/weaviate/client6/v1/api/Config.java @@ -11,6 +11,7 @@ import io.weaviate.client6.v1.internal.ObjectBuilder; import io.weaviate.client6.v1.internal.Timeout; 
import io.weaviate.client6.v1.internal.TokenProvider; +import io.weaviate.client6.v1.internal.TransportOptions; import io.weaviate.client6.v1.internal.grpc.GrpcChannelOptions; import io.weaviate.client6.v1.internal.rest.RestTransportOptions; @@ -181,26 +182,15 @@ public SelfT timeout(int initSeconds, int querySeconds, int insertSeconds) { private static final String HEADER_X_WEAVIATE_CLUSTER_URL = "X-Weaviate-Cluster-URL"; private static final String HEADER_X_WEAVIATE_CLIENT = "X-Weaviate-Client"; - /** - * isWeaviateDomain returns true if the host matches weaviate.io, - * semi.technology, or weaviate.cloud domain. - */ - private static boolean isWeaviateDomain(String host) { - var lower = host.toLowerCase(); - return lower.contains("weaviate.io") || - lower.contains("semi.technology") || - lower.contains("weaviate.cloud"); - } - private static final String VERSION = "weaviate-client-java/" - + ((!BuildInfo.TAGS.isBlank() && BuildInfo.TAGS != "unknown") ? BuildInfo.TAGS - : (BuildInfo.BRANCH + "-" + BuildInfo.COMMIT_ID_ABBREV)); + + ((!BuildInfo.TAGS.isBlank() && BuildInfo.TAGS != "unknown") ? BuildInfo.TAGS + : (BuildInfo.BRANCH + "-" + BuildInfo.COMMIT_ID_ABBREV)); @Override public Config build() { // For clusters hosted on Weaviate Cloud, Weaviate Embedding Service // will be available under the same domain. 
- if (isWeaviateDomain(httpHost) && authentication != null) { + if (TransportOptions.isWeaviateDomain(httpHost) && authentication != null) { setHeader(HEADER_X_WEAVIATE_CLUSTER_URL, "https://" + httpHost + ":" + httpPort); } setHeader(HEADER_X_WEAVIATE_CLIENT, VERSION); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java index 7af8ed549..b49ec987a 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java @@ -1,9 +1,11 @@ package io.weaviate.client6.v1.api.collections; import java.util.Collection; +import java.util.Optional; import java.util.function.Function; import io.weaviate.client6.v1.api.collections.aggregate.WeaviateAggregateClient; +import io.weaviate.client6.v1.api.collections.batch.WeaviateBatchClient; import io.weaviate.client6.v1.api.collections.config.WeaviateConfigClient; import io.weaviate.client6.v1.api.collections.data.WeaviateDataClient; import io.weaviate.client6.v1.api.collections.generate.WeaviateGenerateClient; @@ -23,6 +25,7 @@ public class CollectionHandle { public final WeaviateAggregateClient aggregate; public final WeaviateGenerateClient generate; public final WeaviateTenantsClient tenants; + public final WeaviateBatchClient batch; private final CollectionHandleDefaults defaults; @@ -36,6 +39,7 @@ public CollectionHandle( this.query = new WeaviateQueryClient<>(collection, grpcTransport, defaults); this.generate = new WeaviateGenerateClient<>(collection, grpcTransport, defaults); this.data = new WeaviateDataClient<>(collection, restTransport, grpcTransport, defaults); + this.batch = new WeaviateBatchClient<>(grpcTransport, collection, defaults); this.defaults = defaults; this.tenants = new WeaviateTenantsClient(collection, restTransport, grpcTransport); @@ -48,6 +52,7 @@ private CollectionHandle(CollectionHandle c, 
CollectionHandleDefaul this.query = new WeaviateQueryClient<>(c.query, defaults); this.generate = new WeaviateGenerateClient<>(c.generate, defaults); this.data = new WeaviateDataClient<>(c.data, defaults); + this.batch = new WeaviateBatchClient<>(c.batch, defaults); this.defaults = defaults; this.tenants = c.tenants; @@ -112,7 +117,7 @@ public long size() { } /** Default consistency level for requests. */ - public ConsistencyLevel consistencyLevel() { + public Optional consistencyLevel() { return defaults.consistencyLevel(); } @@ -122,7 +127,7 @@ public CollectionHandle withConsistencyLevel(ConsistencyLevel consi } /** Default tenant for requests. */ - public String tenant() { + public Optional tenant() { return defaults.tenant(); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java index 83d18ed2f..14c551d18 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java @@ -1,6 +1,7 @@ package io.weaviate.client6.v1.api.collections; import java.util.Collection; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.function.Function; @@ -90,7 +91,7 @@ public CompletableFuture size() { } /** Default consistency level for requests. */ - public ConsistencyLevel consistencyLevel() { + public Optional consistencyLevel() { return defaults.consistencyLevel(); } @@ -101,7 +102,7 @@ public CollectionHandleAsync withConsistencyLevel(ConsistencyLevel } /** Default tenant for requests. 
*/ - public String tenant() { + public Optional tenant() { return defaults.tenant(); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaults.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaults.java index c7952222a..47ee0dcba 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaults.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaults.java @@ -1,14 +1,17 @@ package io.weaviate.client6.v1.api.collections; +import static java.util.Objects.requireNonNull; + import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.Optional; import java.util.function.Function; import io.weaviate.client6.v1.api.collections.query.ConsistencyLevel; import io.weaviate.client6.v1.internal.ObjectBuilder; -public record CollectionHandleDefaults(ConsistencyLevel consistencyLevel, String tenant) { +public record CollectionHandleDefaults(Optional consistencyLevel, Optional tenant) { /** * Set default values for query / aggregation requests. * @@ -28,8 +31,12 @@ public static Function> none() return ObjectBuilder.identity(); } + public CollectionHandleDefaults { + requireNonNull(consistencyLevel, "consistencyLevel is null"); + } + public CollectionHandleDefaults(Builder builder) { - this(builder.consistencyLevel, builder.tenant); + this(Optional.ofNullable(builder.consistencyLevel), Optional.ofNullable(builder.tenant)); } public static final class Builder implements ObjectBuilder { @@ -56,16 +63,12 @@ public CollectionHandleDefaults build() { /** Serialize default values to a URL query. 
*/ public Map queryParameters() { - if (consistencyLevel == null && tenant == null) { + if (consistencyLevel.isEmpty() && tenant.isEmpty()) { return Collections.emptyMap(); } - var query = new HashMap(); - if (consistencyLevel != null) { - query.put("consistency_level", consistencyLevel); - } - if (tenant != null) { - query.put("tenant", tenant); - } + Map query = new HashMap(); + consistencyLevel.ifPresent(v -> query.put("consistency_level", v)); + tenant.ifPresent(v -> query.put("tenant", v)); return query; } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AggregateRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AggregateRequest.java index fa8290f64..7d048937a 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AggregateRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AggregateRequest.java @@ -25,9 +25,7 @@ static Rpc { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java new file mode 100644 index 000000000..f3f7a1282 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Batch.java @@ -0,0 +1,350 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import static java.util.Objects.requireNonNull; + +import java.time.Instant; +import java.util.Collection; +import java.util.Comparator; +import java.util.LinkedHashMap; +import java.util.ListIterator; +import java.util.OptionalInt; +import java.util.Set; +import java.util.TreeSet; + +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; + +import io.weaviate.client6.v1.api.collections.batch.Event.Backoff; +import io.weaviate.client6.v1.internal.grpc.GrpcChannelOptions; + +/** + * Batch can be in either of 2 states: + *
    + *
  • Open batch accepts new items and can be resized. + *
  • In-flight batch is sealed: it rejects new items and + * avoids otherwise modifying the {@link #buffer} until it's cleared. + *
+ * + *

Class invariants

+ * + * {@link #maxSize} and {@link #maxSizeBytes} MUST be positive. + * A batch with {@code cap=0} is not useful.
+ * {@link #buffer} size and {@link #sizeBytes} MUST be non-negative.
+ * {@link #buffer} size MUST NOT exceed {@link #maxSize}.
+ * {@link #sizeBytes} MUST NOT exceed {@link #maxSizeBytes}.
+ * {@link #sizeBytes} MUST be 0 if the buffer is empty.
+ * {@link #backlog} MAY only contain items when {@link #buffer} is full
+ * (see the overflow edge-case in {@link #add}).
+ * {@link #pendingMaxSize} is empty for an open batch.
+ *

Synchronization policy

+ * + * @see #inFlight + * @see #isFull + * @see #clear + * @see #checkInvariants + */ +@ThreadSafe +final class Batch { + /** Backlog MUST be confined to the "receiver" thread. */ + private final TreeSet backlog = new TreeSet<>(BacklogItem.comparator()); + + /** + * Items stored in this batch. + */ + @GuardedBy("this") + private final LinkedHashMap buffer; + + /** + * Maximum number of items that can be added to the request. + * Must be greater that zero. + * + *

+ * This is determined by the server's {@link Backoff} instruction. + */ + @GuardedBy("this") + private int maxSize; + /** + * Maximum size of the serialized message in bytes. + * Must be greater than zero. + * + *

+ * This is determined by the {@link GrpcChannelOptions#maxMessageSize}. + */ + @GuardedBy("this") + private final long maxSizeBytes; + + /** Total serialized size of the items in the {@link #buffer}. */ + @GuardedBy("this") + private long sizeBytes; + + /** An in-flight batch is unmodifiable. */ + @GuardedBy("this") + private boolean inFlight = false; + + /** + * Pending update to the {@link #maxSize}. + * + * The value is non-empty when {@link #setMaxSize} is called + * while the batch is {@link #inFlight}. + */ + @GuardedBy("this") + private OptionalInt pendingMaxSize = OptionalInt.empty(); + + Batch(int maxSize, int maxSizeBytes) { + assert maxSize > 0 : "non-positive maxSize"; + + this.maxSizeBytes = MessageSizeUtil.maxSizeBytes(maxSizeBytes); + this.maxSize = maxSize; + this.buffer = new LinkedHashMap<>(maxSize); // LinkedHashMap preserves insertion order. + + checkInvariants(); + } + + /** + * Returns true if batch has reached its capacity, either in terms + * of the item count or the batch's estimated size in bytes. + */ + synchronized boolean isFull() { + return buffer.size() == maxSize || sizeBytes == maxSizeBytes; + } + + /** + * Returns true if the batch's internal buffer is empty. + * If it's primary buffer is empty, its backlog is guaranteed + * to be empty as well. + */ + synchronized boolean isEmpty() { + return buffer.isEmpty(); // sizeBytes == 0 is guaranteed by class invariant. + } + + /** + * Prepare a request to be sent. After calling this method, this batch becomes + * "in-flight": an attempt to {@link #add} more items to it will be rejected + * with an exception. + */ + synchronized Message prepare() { + checkInvariants(); + + inFlight = true; + return builder -> { + buffer.forEach((__, data) -> { + data.appendTo(builder); + }); + }; + } + + /** + * Set the new {@link #maxSize} for this buffer. + * + *

+ * How the size is applied depends on the buffer's current state: + *

    + *
  • When the batch is in-flight, the new limit is stored in + * {@link #pendingMaxSize} and will be applied once the batch is cleared. + *
  • While the batch is still open, the new limit is applied immediately and + * the {@link #pendingMaxSize} is set back to {@link OptionalInt#empty}. If + * the current buffer size exceeds the new limit, the overflow items are moved + * to the {@link #backlog}. + *
+ * + * @param maxSizeNew New batch size limit. + * + * @see #clear + */ + synchronized void setMaxSize(int maxSizeNew) { + checkInvariants(); + + try { + // In-flight batch cannot be modified. + // Store the requested maxSize for later; + // it will be applied on the next ack. + if (inFlight) { + pendingMaxSize = OptionalInt.of(maxSizeNew); + return; + } + + maxSize = maxSizeNew; + pendingMaxSize = OptionalInt.empty(); + + // Buffer still fits under the new limit. + if (buffer.size() <= maxSize) { + return; + } + + // Buffer exceeds the new limit. Move extra items to the backlog (LIFO). + ListIterator extra = buffer.keySet().stream().toList().listIterator(); + while (extra.hasPrevious() && buffer.size() > maxSize) { + addBacklog(buffer.remove(extra.previous())); + } + } finally { + checkInvariants(); + } + } + + /** + * Add a data item to the batch. + * + *

+ * We want to guarantee that, once a work item has been taken from the queue, + * it's going to be eventually executed. Because we cannot know if an item + * will overflow the batch before it's removed from the queue, the simplest + * and safest way to deal with it is to allow {@link Batch} to put + * the overflowing item in the {@link #backlog}. The batch is considered + * full after that. + * + * @throws DataTooBigException If the data exceeds the maximum + * possible batch size. + * @throws IllegalStateException If called on an "in-flight" batch. + * @see #prepare + * @see #inFlight + * @see #clear + */ + synchronized void add(Data data) throws IllegalStateException, DataTooBigException { + requireNonNull(data, "data is null"); + checkInvariants(); + + try { + if (inFlight) { + throw new IllegalStateException("Batch is in-flight"); + } + long remainingBytes = maxSizeBytes - sizeBytes; + if (data.sizeBytes() <= remainingBytes) { + addSafe(data); + return; + } + if (isEmpty()) { + throw new DataTooBigException(data, maxSizeBytes); + } + // One of the class's invariants is that the backlog must not contain + // any items unless the buffer is full. In case this item overflows + // the buffer, we put it in the backlog, but pretend the maxSizeBytes + // has been reached to satisfy the invariant. + // This doubles as a safeguard to ensure the caller cannot add any + // more items to the batch before flushing it. + addBacklog(data); + sizeBytes += remainingBytes; + assert isFull() : "batch must be full after an overflow"; + } finally { + checkInvariants(); + } + } + + /** + * Add a data item to the batch. + * + * This method does not check {@link Data#sizeBytes}, so the caller + * must ensure that this item will not overflow the batch. + */ + private synchronized void addSafe(Data data) { + buffer.put(data.id(), data); + sizeBytes += data.sizeBytes(); + } + + /** Add a data item to the {@link #backlog}. 
*/ + private synchronized void addBacklog(Data data) { + backlog.add(new BacklogItem(data)); + } + + /** + * Clear this batch's internal buffer. + * + *

+ * Once the buffer is pruned, it is re-populated from the backlog + * until the former is full or the latter is exhaused. + * If {@link #pendingMaxSize} is not empty, it is applied + * before re-populating the buffer. + * + * @return IDs removed from the buffer. + */ + synchronized Collection clear() { + checkInvariants(); + + try { + inFlight = false; + + Set removed = Set.copyOf(buffer.keySet()); + buffer.clear(); + sizeBytes = 0; + + if (pendingMaxSize.isPresent()) { + setMaxSize(pendingMaxSize.getAsInt()); + } + + // Populate internal buffer from the backlog. + // We don't need to check the return value of .add(), + // as all items in the backlog are guaranteed to not + // exceed maxSizeBytes. + backlog.stream() + .takeWhile(__ -> !isFull()) + .map(BacklogItem::data) + .forEach(this::addSafe); + + return removed; + } finally { + checkInvariants(); + } + } + + private static record BacklogItem(Data data, Instant createdAt) { + public BacklogItem { + requireNonNull(data, "data is null"); + requireNonNull(createdAt, "createdAt is null"); + } + + /** + * This constructor sets {@link #createdAt} automatically. + * It is not important that this timestamp is different from + * the one in {@link TaskHandle}, as longs as the order is correct. + */ + public BacklogItem(Data data) { + this(data, Instant.now()); + } + + /** Comparator sorts BacklogItems by their creation time. */ + private static Comparator comparator() { + return new Comparator() { + + @Override + public int compare(BacklogItem a, BacklogItem b) { + if (a.equals(b)) { + return 0; + } + + int cmpInstant = a.createdAt.compareTo(b.createdAt); + boolean sameInstant = cmpInstant == 0; + if (sameInstant) { + // We cannot return 0 for two items with different + // contents, as it may result in data loss. + // If they were somehow created in the same instant, + // let them be sorted lexicographically. 
+ return a.data.id().compareTo(b.data.id()); + } + return cmpInstant; + } + }; + } + } + + /** Asserts the invariants of this class. */ + private synchronized void checkInvariants() { + assert maxSize > 0 : "non-positive maxSize"; + assert maxSizeBytes > 0 : "non-positive maxSizeBytes"; + assert sizeBytes >= 0 : "negative sizeBytes"; + assert buffer.size() <= maxSize : "buffer exceeds maxSize"; + assert sizeBytes <= maxSizeBytes : "message exceeds maxSizeBytes"; + if (!isFull()) { + assert backlog.isEmpty() : "backlog not empty when buffer not full"; + } + if (buffer.isEmpty()) { + assert sizeBytes == 0 : "sizeBytes must be 0 when buffer is empty"; + } + + requireNonNull(pendingMaxSize, "pendingMaxSize is null"); + if (!inFlight) { + assert pendingMaxSize.isEmpty() : "open batch has pending maxSize"; + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java new file mode 100644 index 000000000..d52745976 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java @@ -0,0 +1,891 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import static java.util.Objects.requireNonNull; + +import java.io.Closeable; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.EnumSet; +import java.util.List; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import 
java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.BiConsumer; + +import javax.annotation.concurrent.GuardedBy; + +import io.grpc.Status; +import io.grpc.stub.StreamObserver; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.client6.v1.api.collections.batch.Event.ClientError; +import io.weaviate.client6.v1.api.collections.batch.Event.StreamHangup; +import io.weaviate.client6.v1.api.collections.data.BatchReference; +import io.weaviate.client6.v1.api.collections.data.InsertManyRequest; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; + +/** + * BatchContext stores the state of an active batch process + * and controls its lifecycle. + * + *

State

+ * + *

Lifecycle

+ * + *

Cancellation policy

+ * + * @param the shape of properties for inserted objects. + */ +public final class BatchContext implements Closeable { + private final int DEFAULT_BATCH_SIZE = 1_000; + private final int DEFAULT_QUEUE_SIZE = 100; + private final int MAX_RECONNECT_RETRIES = 5; + + private final CollectionDescriptor collectionDescriptor; + private final CollectionHandleDefaults collectionHandleDefaults; + + /** + * Internal execution service. It's lifecycle is bound to that of the + * BatchContext: it's started when the context is initialized + * and shutdown on {@link #close}. + * + *

+ * In the event of abrupt stream termination ({@link Recv#onError} is called), + * the "recv" thread MAY shutdown this service in order to interrupt the "send" + * thread; the latter may be blocked on {@link Send#awaitCanSend} or + * {@link Send#awaitCanPrepareNext}. + */ + private final ExecutorService sendExec = Executors.newSingleThreadExecutor(); + + /** + * Scheduled thread pool for delayed tasks. + * + * @see Oom + * @see Reconnecting + */ + private final ScheduledExecutorService scheduledExec = Executors.newScheduledThreadPool(1); + + /** The thread that created the context. */ + private final Thread parent = Thread.currentThread(); + + /** Stream factory creates new streams. */ + private final StreamFactory streamFactory; + + /** + * Queue publishes insert tasks from the main thread to the "sender". + * It has a maximum capacity of {@link #DEFAULT_QUEUE_SIZE}. + * + * Send {@link TaskHandle#POISON} to gracefully shutdown the "sender" + * thread. The same queue may be re-used with a different "sender", + * e.g. after {@link #reconnect}, but only when the new thread is known + * to have started. Otherwise the thread trying to put an item on + * the queue will block indefinitely. + */ + private final BlockingQueue queue; + + /** + * Work-in-progress items. + * + * An item is added to the {@link #wip} map after the Sender successfully + * adds it to the {@link #batch} and is removed once the server reports + * back the result (whether success of failure). + */ + private final ConcurrentMap wip = new ConcurrentHashMap<>(); + + /** + * Current batch. + * + *

+ * An item is added to the {@link #batch} after the Sender pulls it + * from the queue and remains there until it's Ack'ed. + */ + private final Batch batch; + + /** + * State encapsulates state-dependent behavior of the {@link BatchContext}. + * Before reading {@link #state}, a thread MUST acquire {@link #lock}. + */ + @GuardedBy("lock") + private State state; + /** lock synchronizes access to {@link #state}. */ + private final Lock lock = new ReentrantLock(); + /** stateChanged notifies threads about a state transition. */ + private final Condition stateChanged = lock.newCondition(); + + /** + * Client-side part of the current stream, created on {@link #start}. + * Other threads MAY use stream but MUST NOT update this field on their own. + */ + private volatile StreamObserver messages; + + /** Handle for the "send" thread. Use {@link Future#cancel} to interrupt it. */ + private volatile Future send; + + /** + * Latch reaches zero once both "send" (client side) and "recv" (server side) + * parts of the stream have closed. After a {@link reconnect}, the latch is + * reset. + */ + private volatile CountDownLatch workers; + + /** closing completes the stream. */ + private final CompletableFuture closing = new CompletableFuture<>(); + + /** Executor for performing the shutdown sequence. */ + private final ExecutorService shutdownExec = Executors.newSingleThreadExecutor(); + + /** Lightway check to ensure users cannot send on a closed context. */ + private volatile boolean closed; + + // /** Closing state. */ + // private volatile Closing closing; + + // /** + // * setClosing trasitions BatchContext to {@link Closing} state exactly once. + // * Once this method returns, the caller can call {@code closing.await()}. 
+ // */ + // void setClosing(Exception ex) { + // if (closing == null) { + // synchronized (Closing.class) { + // if (closing == null) { + // closing = new Closing(ex); + // setState(closing); + // } + // } + // } + // } + + BatchContext( + StreamFactory streamFactory, + int maxSizeBytes, + CollectionDescriptor collectionDescriptor, + CollectionHandleDefaults collectionHandleDefaults) { + this.streamFactory = requireNonNull(streamFactory, "streamFactory is null"); + this.collectionDescriptor = requireNonNull(collectionDescriptor, "collectionDescriptor is null"); + this.collectionHandleDefaults = requireNonNull(collectionHandleDefaults, "collectionHandleDefaults is null"); + + this.queue = new ArrayBlockingQueue<>(DEFAULT_QUEUE_SIZE); + this.batch = new Batch(DEFAULT_BATCH_SIZE, maxSizeBytes); + } + + /** Add {@link WeaviateObject} to the batch. */ + public TaskHandle add(WeaviateObject object) throws InterruptedException { + TaskHandle handle = new TaskHandle( + object, + InsertManyRequest.buildObject(object, collectionDescriptor, collectionHandleDefaults)); + return add(handle); + } + + /** Add {@link BatchReference} to the batch. */ + public TaskHandle add(BatchReference reference) throws InterruptedException { + TaskHandle handle = new TaskHandle( + reference, + InsertManyRequest.buildReference(reference, collectionHandleDefaults.tenant())); + return add(handle); + } + + void start() { + start(AWAIT_STARTED); + } + + void start(State nextState) { + workers = new CountDownLatch(2); + + messages = streamFactory.createStream(new Recv()); + + // Start the stream and await Started message. + messages.onNext(Message.start(collectionHandleDefaults.consistencyLevel())); + setState(nextState); + + // "send" routine must start after the nextState has been set. + send = sendExec.submit(new Send()); + } + + /** + * Reconnect waits for "send" and "recv" streams to exit + * and restarts the process with a new stream. 
+ * + * @param reconnecting Reconnecting instance that called reconnect. + */ + void reconnect(Reconnecting reconnecting) throws InterruptedException, ExecutionException { + workers.await(); + send.get(); + start(reconnecting); + } + + /** + * Retry a task. + * + * BatchContext does not impose any limit on the number of times a task can + * be retried -- it is up to the user to implement an appropriate retry policy. + * + * @see TaskHandle#timesRetried + */ + public TaskHandle retry(TaskHandle taskHandle) throws InterruptedException { + return add(taskHandle.retry()); + } + + /** + * Close attempts to drain the queue and send all remaining items. + * Calling any of BatchContext's public methods afterwards will + * result in an {@link IllegalStateException}. + * + * @throws IOException Propagates an exception + * if one has occurred in the meantime. + */ + @Override + public void close() throws IOException { + closed = true; + + try { + shutdown(); + } catch (InterruptedException | ExecutionException e) { + if (e instanceof InterruptedException || + e.getCause() instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + throw new IOException(e.getCause()); + } finally { + shutdownExecutors(); + } + } + + private void shutdown() throws InterruptedException, ExecutionException { + CompletableFuture gracefulShutdown = CompletableFuture.runAsync(() -> { + try { + // Poison the queue -- this will signal "send" to drain the remaing + // items in the batch and in the backlog and exit. + // + // If shutdownNow has been called previously and the "send" routine + // has been interrupted, this may block indefinitely. + // However, shutdownNow ensures that `closing` future is resolved. + queue.put(TaskHandle.POISON); + + // Wait for the send to exit before closing our end of the stream. + send.get(); + messages.onNext(Message.stop()); + messages.onCompleted(); + + // Wait for both "send" and "recv" to exit. 
+ workers.await(); + closing.complete(null); + } catch (Exception e) { + closing.completeExceptionally(e); + } + + }, shutdownExec); + + // Complete shutdown as soon as one of these futures are completed. + // - gracefulShutdown completes if we managed to shutdown normally. + // - closing may complete sooner if shutdownNow is called. + CompletableFuture.anyOf(closing, gracefulShutdown).get(); + } + + private void shutdownNow(Exception ex) { + // Terminate the "send" routine and wait for it to exit. + // Since we're already in the error state we do not care + // much if it throws or not. + send.cancel(true); + try { + send.get(); + } catch (Exception e) { + } + + // Now report this error to the server and close the stream. + closing.completeExceptionally(ex); + messages.onError(Status.INTERNAL.withCause(ex).asRuntimeException()); + + // Since shutdownNow is never triggerred by the "main" thread, + // it may be blocked on trying to add to the queue. While batch + // context is active, we own this thread and may interrupt it. + parent.interrupt(); + } + + private void shutdownExecutors() { + BiConsumer> assertEmpty = (name, pending) -> { + assert pending.isEmpty() : "'%s' service had %d tasks awaiting execution" + .formatted(pending.size(), name); + }; + + List pending; + + pending = sendExec.shutdownNow(); + assertEmpty.accept("send", pending); + + pending = scheduledExec.shutdownNow(); + assertEmpty.accept("oom", pending); + + pending = shutdownExec.shutdownNow(); + assertEmpty.accept("shutdown", pending); + } + + /** Set the new state and notify awaiting threads. */ + void setState(State nextState) { + requireNonNull(nextState, "nextState is null"); + + lock.lock(); + System.out.println("setState " + state + " => " + nextState); + try { + State prev = state; + state = nextState; + state.onEnter(prev); + stateChanged.signal(); + } finally { + lock.unlock(); + } + } + + /** + * onEvent delegates event handling to {@link #state}. + * + *

+ * Be mindful that most of the time this callback will run in a hot path + * on a gRPC thread. {@link State} implementations SHOULD offload any + * blocking operations to one of the provided executors. + * + * @see #scheduledExec + */ + private void onEvent(Event event) { + lock.lock(); + try { + System.out.println("onEvent " + event); + state.onEvent(event); + } finally { + lock.unlock(); + } + } + + private TaskHandle add(final TaskHandle taskHandle) throws InterruptedException { + if (closed) { + throw new IllegalStateException("BatchContext is closed"); + } + + TaskHandle existing = wip.get(taskHandle.id()); + if (existing != null) { + throw new DuplicateTaskException(taskHandle, existing); + } + + queue.put(taskHandle); + return taskHandle; + } + + private final class Send implements Runnable { + + @Override + public void run() { + try { + trySend(); + } finally { + workers.countDown(); + } + } + + /** + * trySend consumes {@link #queue} tasks and sends them in batches until it + * encounters a {@link TaskHandle#POISON} or is otherwise interrupted. + */ + private void trySend() { + try { + awaitCanPrepareNext(); + + while (!Thread.currentThread().isInterrupted()) { + if (batch.isFull()) { + System.out.println("==[send batch]==>"); + send(); + } + + TaskHandle task = queue.take(); + + if (task == TaskHandle.POISON) { + System.out.println("took POISON"); + drain(); + return; + } + + Data data = task.data(); + batch.add(data); + + TaskHandle existing = wip.put(task.id(), task); + assert existing == null : "duplicate tasks in progress, id=" + existing.id(); + } + } catch (InterruptedException ignored) { + Thread.currentThread().interrupt(); + } catch (Exception e) { + onEvent(new Event.ClientError(e)); + return; + } + } + + /** + * Send the current portion of batch items. After this method returns, the batch + * is guaranteed to have space for at least one the next item (not full). 
+ */ + private void send() throws InterruptedException { + // Continue flushing until we get the batch to not a "not full" state. + // This is to account for the backlog, which might re-fill the batch + // after .clear(). + while (batch.isFull()) { + flush(); + } + assert !batch.isFull() : "batch is full after send"; + } + + /** + * Send all remaining items in the batch. After this method returns, the batch + * is guaranteed to be empty. + */ + private void drain() throws InterruptedException { + // To correctly drain the batch, we flush repeatedly + // until the batch becomes empty, as clearing a batch + // after an ACK might re-populate it from its internal backlog. + while (!batch.isEmpty()) { + flush(); + } + assert batch.isEmpty() : "batch not empty after drain"; + } + + private void flush() throws InterruptedException { + awaitCanSend(); + messages.onNext(batch.prepare()); + setState(IN_FLIGHT); + + // When we get into OOM / ServerShuttingDown state, then we can be certain that + // there isn't any reason to keep waiting for the ACKs. However, we should not + // exit without either taking a poison pill from the queue, + // or being interrupted, as this risks blocking the producer (main) thread. + awaitCanPrepareNext(); + } + + /** Block until the current state allows {@link State#canSend}. */ + private void awaitCanSend() throws InterruptedException { + lock.lock(); + try { + while (!state.canSend()) { + stateChanged.await(); + } + } finally { + lock.unlock(); + } + } + + /** + * Block until the current state allows {@link State#canPrepareNext}. + * + *

+ * Depending on the BatchContext lifecycle, the semantics of + * "await can prepare next" can be one of "message is ACK'ed" + * "the stream has started", or, more generally, + * "it is safe to take a next item from the queue and add it to the batch". + */ + private void awaitCanPrepareNext() throws InterruptedException { + lock.lock(); + try { + while (!state.canPrepareNext()) { + stateChanged.await(); + } + } finally { + lock.unlock(); + } + } + } + + private final class Recv implements StreamObserver { + + @Override + public void onNext(Event event) { + onEvent(event); + } + + /** + * EOF for the server-side stream. + * By the time this is called, the client-side of the stream had been closed + * and the "send" thread has either exited or is on its way there. + */ + @Override + public void onCompleted() { + try { + onEvent(Event.EOF); + } finally { + workers.countDown(); + } + } + + /** An exception occurred either on our end or in the channel internals. */ + @Override + public void onError(Throwable t) { + try { + onEvent(Event.StreamHangup.fromThrowable(t)); + } finally { + workers.countDown(); + } + } + } + + final State AWAIT_STARTED = new BaseState("AWAIT_STARTED", BaseState.Action.PREPARE_NEXT) { + @Override + public void onEvent(Event event) { + if (requireNonNull(event, "event is null") == Event.STARTED) { + setState(ACTIVE); + } else { + super.onEvent(event); + } + } + }; + final State ACTIVE = new BaseState("ACTIVE", BaseState.Action.PREPARE_NEXT, BaseState.Action.SEND); + final State IN_FLIGHT = new BaseState("IN_FLIGHT") { + @Override + public void onEvent(Event event) { + requireNonNull(event, "event is null"); + + if (event instanceof Event.Acks acks) { + Collection removed = batch.clear(); + if (!acks.acked().containsAll(removed)) { + throw ProtocolViolationException.incompleteAcks(List.copyOf(removed)); + } + acks.acked().forEach(id -> { + TaskHandle task = wip.get(id); + if (task != null) { + task.setAcked(); + } + }); + setState(ACTIVE); + } 
else if (event instanceof Event.Oom oom) { + setState(new Oom(oom.delaySeconds())); + } else { + super.onEvent(event); + } + } + }; + + private class BaseState implements State { + /** State's display name for logging. */ + private final String name; + /** Actions permitted in this state. */ + private final EnumSet permitted; + + enum Action { + /** + * Thy system is allowed to accept new items from the user + * and populate the next batch. + */ + PREPARE_NEXT, + + /** The system is allowed to send the next batch once it's ready. */ + SEND; + } + + /** + * @param name Display name. + * @param permitted Actions permitted in this state. + */ + protected BaseState(String name, Action... permitted) { + this.name = name; + this.permitted = requireNonNull(permitted, "actions is null").length == 0 + ? EnumSet.noneOf(Action.class) + : EnumSet.copyOf(Arrays.asList(permitted)); + } + + @Override + public void onEnter(State prev) { + } + + @Override + public boolean canSend() { + return permitted.contains(Action.SEND); + } + + @Override + public boolean canPrepareNext() { + return permitted.contains(Action.PREPARE_NEXT); + } + + /** + * Handle events which may arrive at any moment without violating the protocol. + * + *

    + *
  • {@link Event.Results} -- update tasks in {@link #wip} and remove them. + *
  • {@link Event.Backoff} -- adjust batch size. + *
  • {@link Event#SHUTTING_DOWN} -- transition into + * {@link ServerShuttingDown}. + *
  • {@link Event.StreamHangup} -- transition into {@link Reconnecting} state. + *
  • {@link Event.ClientError} -- shutdown the service immediately. + *
+ * + * @throws ProtocolViolationException If event cannot be handled in this state. + * @see BatchContext#shutdownNow + */ + @Override + public void onEvent(Event event) { + requireNonNull(event, "event is null"); + + if (event instanceof Event.Results results) { + onResults(results); + } else if (event instanceof Event.Backoff backoff) { + onBackoff(backoff); + } else if (event == Event.SHUTTING_DOWN) { + onShuttingDown(); + } else if (event instanceof Event.StreamHangup || event == Event.EOF) { + onStreamClosed(event); + } else if (event instanceof Event.ClientError error) { + onClientError(error); + } else { + throw ProtocolViolationException.illegalStateTransition(this, event); + } + } + + private final void onResults(Event.Results results) { + results.successful().forEach(id -> wip.remove(id).setSuccess()); + results.errors().forEach((id, error) -> wip.remove(id).setError(error)); + } + + private final void onBackoff(Event.Backoff backoff) { + System.out.print("========== BACKOFF =============="); + System.out.print(backoff.maxSize()); + System.out.print("================================="); + batch.setMaxSize(backoff.maxSize()); + } + + private final void onShuttingDown() { + setState(new ServerShuttingDown(this)); + } + + private final void onStreamClosed(Event event) { + if (event instanceof Event.StreamHangup hangup) { + hangup.exception().printStackTrace(); + } + if (!send.isDone()) { + setState(new Reconnecting(MAX_RECONNECT_RETRIES)); + } + } + + private final void onClientError(Event.ClientError error) { + shutdownNow(error.exception()); + } + + @Override + public String toString() { + return name; + } + } + + /** + * Oom waits for {@link Event#SHUTTING_DOWN} up to a specified amount of time, + * after which it will force stream termiation by imitating server shutdown. 
+ */ + private final class Oom extends BaseState { + private final long delaySeconds; + private ScheduledFuture shutdown; + + private Oom(long delaySeconds) { + super("OOM"); + this.delaySeconds = delaySeconds; + } + + @Override + public void onEnter(State prev) { + shutdown = scheduledExec.schedule(this::initiateShutdown, delaySeconds, TimeUnit.SECONDS); + } + + /** Imitate server shutdown sequence. */ + private void initiateShutdown() { + // We cannot route event handling via normal BatchContext#onEvent, because + // it delegates to the current state, which is Oom. If Oom#onEvent were to + // receive an Event.SHUTTING_DOWN, it would cancel this execution of this + // very sequence. Instead, we delegate to our parent BaseState which normally + // handles these events. + if (Thread.currentThread().isInterrupted()) { + super.onEvent(Event.SHUTTING_DOWN); + } + if (Thread.currentThread().isInterrupted()) { + super.onEvent(Event.EOF); + } + } + + @Override + public void onEvent(Event event) { + requireNonNull(event, "event"); + if (event == Event.SHUTTING_DOWN || + event instanceof StreamHangup || + event instanceof ClientError) { + shutdown.cancel(true); + try { + shutdown.get(); + } catch (CancellationException ignored) { + } catch (InterruptedException ignored) { + // Recv is running on a thread from gRPC's internal thread pool, + // so, while onEvent allows InterruptedException to stay responsive, + // in practice this thread will only be interrupted by the thread pool, + // which already knows it's being shut down. + } catch (ExecutionException ex) { + onEvent(new Event.ClientError(ex)); + } + } + super.onEvent(event); + } + } + + /** + * ServerShuttingDown allows preparing the next batch + * unless the server's OOM'ed on the previous one. + * Once set, the state will shutdown {@link BatchContext#sendExec} + * to instruct the "send" thread to close our part of the stream. 
+ */ + private final class ServerShuttingDown extends BaseState { + private final boolean canPrepareNext; + + private ServerShuttingDown(State previous) { + super("SERVER_SHUTTING_DOWN"); + this.canPrepareNext = requireNonNull(previous, "previous is null").getClass() != Oom.class; + } + + @Override + public boolean canPrepareNext() { + return canPrepareNext; + } + + @Override + public boolean canSend() { + return false; + } + + @Override + public void onEnter(State prev) { + send.cancel(true); + } + } + + /** + * Reconnecting state is entererd either by the server finishing a shutdown + * and closing it's end of the stream or an unexpected stream hangup. + * + * @see Recv#onCompleted graceful server shutdown + * @see Recv#onError stream hangup + */ + private final class Reconnecting extends BaseState { + private final int maxRetries; + private int retries = 0; + + private Reconnecting(int maxRetries) { + super("RECONNECTING", Action.PREPARE_NEXT); + this.maxRetries = maxRetries; + } + + @Override + public void onEnter(State prev) { + // The reconnected state is re-set every time the stream restarts. + // This ensures that onEnter hook is only called the first + // time we enter Reconnecting state. + if (prev == this) { + return; + } + + send.cancel(true); + + if (prev.getClass() != ServerShuttingDown.class) { + // This is NOT an orderly shutdown, we're reconnecting after a stream hangup. + // Assume all WIP items have been lost and re-submit everything. + // All items in the batch are contained in WIP, so it is safe to discard the + // batch entirely and re-populate from WIP. + while (!batch.isEmpty()) { + batch.clear(); + } + + // Unlike during normal operation, we will not stop when batch.isFull(). + // Batch#add guarantees that data will not be discarded in the event of + // an overflow -- all extra items are placed into the backlog, which is + // unbounded. 
+        wip.values().forEach(task -> batch.add(task.data()));
+      }
+
+      reconnectNow();
+    }
+
+    @Override
+    public void onEvent(Event event) {
+      assert retries <= maxRetries : "maxRetries exceeded";
+
+      if (event == Event.STARTED) {
+        setState(ACTIVE);
+      } else if (event instanceof Event.StreamHangup) {
+        if (retries == maxRetries) {
+          onEvent(new Event.ClientError(new IOException("Server unavailable")));
+        } else {
+          reconnectAfter(1L << retries); // exponential backoff: 2^retries seconds ('^' would be XOR in Java)
+        }
+      }
+
+      assert retries <= maxRetries : "maxRetries exceeded";
+    }
+
+    /** Reconnect with no delay. */
+    private void reconnectNow() {
+      reconnectAfter(0);
+    }
+
+    /**
+     * Schedule a task to {@link #reconnect} after a delay.
+     *
+     * @param delaySeconds Delay in seconds.
+     *
+     * @apiNote The task is scheduled on {@link #scheduledExec} even if
+     * {@code delaySeconds == 0} to avoid blocking gRPC worker thread,
+     * where the {@link BatchContext#onEvent} callback runs.
+     */
+    private void reconnectAfter(long delaySeconds) {
+      retries++;
+
+      scheduledExec.schedule(() -> {
+        try {
+          reconnect(this);
+        } catch (InterruptedException e) {
+          Thread.currentThread().interrupt();
+        } catch (ExecutionException e) {
+          onEvent(new Event.ClientError(e));
+        }
+      }, delaySeconds, TimeUnit.SECONDS);
+    }
+  }
+
+  // --------------------------------------------------------------------------
+
+  private final ScheduledExecutorService reconnectExec = Executors.newScheduledThreadPool(1);
+
+  void scheduleReconnect(int reconnectIntervalSeconds) {
+    reconnectExec.scheduleWithFixedDelay(() -> {
+      if (Thread.currentThread().isInterrupted()) { // NOTE(review): should this be !isInterrupted()? verify intent (same pattern in Oom#initiateShutdown)
+        onEvent(Event.SHUTTING_DOWN);
+      }
+      if (Thread.currentThread().isInterrupted()) {
+        onEvent(Event.EOF);
+      }
+
+      // We want to count down from the moment we re-opened the stream,
+      // not from the moment we initiated the sequence.
+      lock.lock();
+      try {
+        while (state != ACTIVE) {
+          stateChanged.await();
+        }
+      } catch (InterruptedException ignored) {
+        // Let the process exit normally.
+ } finally { + lock.unlock(); + } + }, reconnectIntervalSeconds, reconnectIntervalSeconds, TimeUnit.SECONDS); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Data.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Data.java new file mode 100644 index 000000000..dba9e79ad --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Data.java @@ -0,0 +1,94 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import static java.util.Objects.requireNonNull; + +import javax.annotation.concurrent.Immutable; + +import com.google.protobuf.GeneratedMessage; +import com.google.protobuf.GeneratedMessageV3; + +import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.client6.v1.api.collections.data.ObjectReference; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch; + +@Immutable +@SuppressWarnings("deprecation") // protoc uses GeneratedMessageV3 +class Data implements Message { + + /** + * Raw input value, as passed by the user. + */ + private final Object raw; + + /** + * Task ID. Depending on the underlying object, this will either be + * {@link WeaviateObject#uuid} or {@link ObjectReference#beacon}. + * + * Since UUIDs and beacons cannot clash, ID does not encode any + * information about the underlying data type. + */ + private final String id; + + /** + * Serialized representation of the {@link #raw}. This valus is immutable + * for the entire lifecycle of the handle. + */ + private final GeneratedMessage.ExtendableMessage message; + + /** Estimated size of the {@link #message} when serialized. 
*/ + private final int sizeBytes; + + enum Type { + OBJECT(WeaviateProtoBatch.BatchStreamRequest.Data.OBJECTS_FIELD_NUMBER), + REFERENCE(WeaviateProtoBatch.BatchStreamRequest.Data.REFERENCES_FIELD_NUMBER); + + private final int fieldNumber; + + private Type(int fieldNumber) { + this.fieldNumber = fieldNumber; + } + + public int fieldNumber() { + return fieldNumber; + } + } + + private Data(Object raw, String id, GeneratedMessage.ExtendableMessage message, int sizeBytes) { + this.raw = requireNonNull(raw, "raw is null"); + this.id = requireNonNull(id, "id is null"); + this.message = requireNonNull(message, "message is null"); + + assert sizeBytes >= 0; + this.sizeBytes = sizeBytes; + } + + Data(Object raw, String id, GeneratedMessage.ExtendableMessage message, + Type type) { + this(raw, id, message, MessageSizeUtil.ofDataField(message, type)); + } + + String id() { + return id; + } + + /** Serialized data size in bytes. */ + int sizeBytes() { + return sizeBytes; + } + + @Override + public void appendTo(WeaviateProtoBatch.BatchStreamRequest.Builder builder) { + WeaviateProtoBatch.BatchStreamRequest.Data.Builder data = requireNonNull(builder, "builder is null") + .getDataBuilder(); + if (message instanceof WeaviateProtoBatch.BatchObject object) { + data.getObjectsBuilder().addValues(object); + } else if (message instanceof WeaviateProtoBatch.BatchReference ref) { + data.getReferencesBuilder().addValues(ref); + } + } + + @Override + public String toString() { + return "%s (%s)".formatted(raw.getClass().getSimpleName(), id); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/DataTooBigException.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/DataTooBigException.java new file mode 100644 index 000000000..5160bd27c --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/DataTooBigException.java @@ -0,0 +1,16 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import static 
java.util.Objects.requireNonNull;
+
+import io.weaviate.client6.v1.api.WeaviateException;
+
+/**
+ * DataTooBigException is thrown when a single object exceeds
+ * the maximum size of a gRPC message.
+ */
+public class DataTooBigException extends WeaviateException {
+  DataTooBigException(Data data, long maxSizeBytes) {
+    super("%s with size=%dB exceeds maximum message size %dB".formatted(
+        requireNonNull(data, "data is null"), data.sizeBytes(), maxSizeBytes));
+  }
+}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/DuplicateTaskException.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/DuplicateTaskException.java
new file mode 100644
index 000000000..f0949a6c0
--- /dev/null
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/DuplicateTaskException.java
@@ -0,0 +1,23 @@
+package io.weaviate.client6.v1.api.collections.batch;
+
+import io.weaviate.client6.v1.api.WeaviateException;
+
+/**
+ * DuplicateTaskException is thrown if a task is submitted to the batch
+ * while another task with the same ID is in progress.
+ */
+public class DuplicateTaskException extends WeaviateException {
+  private final TaskHandle existing;
+
+  DuplicateTaskException(TaskHandle duplicate, TaskHandle existing) {
+    super("%s cannot be added to the batch while another task with the same ID is in progress".formatted(duplicate));
+    this.existing = existing;
+  }
+
+  /**
+   * Get the currently in-progress handle that's a duplicate of the one submitted.
+ */ + public TaskHandle getExisting() { + return existing; + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java new file mode 100644 index 000000000..ba3f6b208 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java @@ -0,0 +1,162 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import static java.util.Objects.requireNonNull; + +import java.util.Collection; +import java.util.List; +import java.util.Map; + +import io.grpc.Status; +import io.weaviate.client6.v1.api.collections.batch.Event.Acks; +import io.weaviate.client6.v1.api.collections.batch.Event.Backoff; +import io.weaviate.client6.v1.api.collections.batch.Event.ClientError; +import io.weaviate.client6.v1.api.collections.batch.Event.Oom; +import io.weaviate.client6.v1.api.collections.batch.Event.Results; +import io.weaviate.client6.v1.api.collections.batch.Event.Started; +import io.weaviate.client6.v1.api.collections.batch.Event.StreamHangup; +import io.weaviate.client6.v1.api.collections.batch.Event.TerminationEvent; + +sealed interface Event + permits Started, Acks, Results, Backoff, Oom, TerminationEvent, StreamHangup, ClientError { + + final static Event STARTED = new Started(); + final static Event SHUTTING_DOWN = TerminationEvent.SHUTTING_DOWN; + final static Event EOF = TerminationEvent.EOF; + + /** + * The server has acknowledged our Start message and is ready to receive data. + */ + record Started() implements Event { + } + + /** + * The server has added items from the previous message to its internal + * work queue, client MAY send the next batch. + * + *

+   * The protocol guarantees that {@link Acks} will contain IDs for all
+   * items sent in the previous batch.
+   */
+  record Acks(Collection<String> acked) implements Event {
+    public Acks {
+      acked = List.copyOf(requireNonNull(acked, "acked is null"));
+    }
+
+    @Override
+    public String toString() {
+      return "Acks";
+    }
+  }
+
+  /**
+   * Results for the insertion of previous batches.
+   *
+   * <p>

+ * We assume that the server may return partial results, or return + * results out of the order of inserting messages. + */ + record Results(Collection successful, Map errors) implements Event { + public Results { + successful = List.copyOf(requireNonNull(successful, "successful is null")); + errors = Map.copyOf(requireNonNull(errors, "errors is null")); + } + + @Override + public String toString() { + return "Results"; + } + } + + /** + * Backoff communicates the optimal batch size (number of objects) + * with respect to the current load on the server. + * + *

+ * Backoff is an instruction, not a recommendation. + * On receiving this message, the client must ensure that + * all messages it produces, including the one being prepared, + * do not exceed the size limit indicated by {@link #maxSize} + * until the server sends another Backoff message. The limit + * MUST also be respected after a {@link BatchContext#reconnect}. + * + *

+ * The client MAY use the latest {@link #maxSize} as the default + * message limit in a new {@link BatchContext}, but is not required to. + */ + record Backoff(int maxSize) implements Event { + + @Override + public String toString() { + return "Backoff"; + } + } + + /** + * Out-Of-Memory. + * + *

+ * Items sent in the previous request cannot be accepted, + * as inserting them may exhaust server's available disk space. + * On receiving this message, the client MUST stop producing + * messages immediately and await {@link #SHUTTING_DOWN} event. + * + *

+ * Oom is the sibling of {@link Acks} with the opposite effect. + * The protocol guarantees that the server will respond with either of + * the two, but never both. + */ + record Oom(int delaySeconds) implements Event { + } + + /** Events that are part of the server's graceful shutdown strategy. */ + enum TerminationEvent implements Event { + /** + * Server shutdown in progress. + * + *

+     * The server began the process of graceful shutdown, due to a
+     * scale-up event (if it previously reported {@link Oom}) or
+     * some other external event.
+     * On receiving this message, the client MUST stop producing
+     * messages immediately, close its side of the stream, and
+     * continue reading the server's messages until {@link #EOF}.
+     */
+    SHUTTING_DOWN,
+
+    /**
+     * Stream EOF.
+     *
+     * <p>

+     * The server will not receive any messages. If the client
+     * has more data to send, it SHOULD re-connect to another instance
+     * by re-opening the stream and continue processing the batch.
+     * If the client has previously sent {@link Message#STOP}, it can
+     * safely exit.
+     */
+    EOF;
+  }
+
+  /**
+   * StreamHangup means the RPC is "dead": the stream is closed
+   * and using it will result in an {@link IllegalStateException}.
+   */
+  record StreamHangup(Exception exception) implements Event {
+    static StreamHangup fromThrowable(Throwable t) {
+      Status status = Status.fromThrowable(t);
+      return new StreamHangup(status.asException());
+    }
+  }
+
+  /**
+   * ClientError means a client-side exception has happened,
+   * and is meant primarily for the "send" thread to propagate
+   * any exception it might catch.
+   *
+   * <p>

+ * This MUST be treated as an irrecoverable condition, because + * it is likely caused by an internal issue (an NPE) or a bad + * input ({@link DataTooBigException}). + */ + record ClientError(Exception exception) implements Event { + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Message.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Message.java new file mode 100644 index 000000000..9d387d1b1 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Message.java @@ -0,0 +1,32 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import static java.util.Objects.requireNonNull; + +import java.util.Optional; + +import io.weaviate.client6.v1.api.collections.query.ConsistencyLevel; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch; + +@FunctionalInterface +interface Message { + void appendTo(WeaviateProtoBatch.BatchStreamRequest.Builder builder); + + /** Create a Start message. */ + static Message start(Optional consistencyLevel) { + requireNonNull(consistencyLevel, "consistencyLevel is null"); + + final WeaviateProtoBatch.BatchStreamRequest.Start.Builder start = WeaviateProtoBatch.BatchStreamRequest.Start + .newBuilder(); + consistencyLevel.ifPresent(value -> value.appendTo(start)); + return builder -> builder.setStart(start); + } + + /** Create a Stop message. 
*/ + static Message stop() { + return STOP; + } + + static final Message STOP = builder -> builder + .setStop(WeaviateProtoBatch.BatchStreamRequest.Stop.getDefaultInstance()); + +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/MessageSizeUtil.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/MessageSizeUtil.java new file mode 100644 index 000000000..f3c934232 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/MessageSizeUtil.java @@ -0,0 +1,44 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import static java.util.Objects.requireNonNull; + +import com.google.protobuf.CodedOutputStream; +import com.google.protobuf.GeneratedMessage; +import com.google.protobuf.GeneratedMessageV3; + +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch; + +final class MessageSizeUtil { + private static int DATA_TAG_SIZE = CodedOutputStream + .computeTagSize(WeaviateProtoBatch.BatchStreamRequest.DATA_FIELD_NUMBER); + + private MessageSizeUtil() { + } + + /** + * Adjust batch byte-size limit to account for the + * {@link WeaviateProtoBatch.BatchStreamRequest.Data} container. + * + *

+ * A protobuf field has layout {@code [tag][lenght(payload)][payload]}, + * so to estimate the batch size correctly we must account for "tag" + * and "length", not just the raw payload. + */ + static long maxSizeBytes(long maxSizeBytes) { + if (maxSizeBytes <= DATA_TAG_SIZE) { + throw new IllegalArgumentException("Maximum batch size must be at least %dB".formatted(DATA_TAG_SIZE)); + } + return maxSizeBytes - DATA_TAG_SIZE; + } + + /** + * Calculate the size of a serialized + * {@link WeaviateProtoBatch.BatchStreamRequest.Data} field. + */ + @SuppressWarnings("deprecation") // protoc uses GeneratedMessageV3 + static int ofDataField(GeneratedMessage.ExtendableMessage message, Data.Type type) { + requireNonNull(type, "type is null"); + requireNonNull(message, "message is null"); + return CodedOutputStream.computeMessageSize(type.fieldNumber(), message); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/ProtocolViolationException.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/ProtocolViolationException.java new file mode 100644 index 000000000..d97b4839c --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/ProtocolViolationException.java @@ -0,0 +1,48 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import static java.util.Objects.requireNonNull; + +import java.util.List; + +import io.weaviate.client6.v1.api.WeaviateException; + +/** + * ProtocolViolationException describes unexpected server behavior in violation + * of the SSB protocol. + * + *

+ * This exception cannot be handled in a meaningful way and should be reported + * to the upstream Weaviate + * project. + */ +public class ProtocolViolationException extends WeaviateException { + ProtocolViolationException(String message) { + super(message); + } + + /** + * Protocol violated because an event arrived while the client is in a state + * which doesn't expect to handle this event. + * + * @param current Current {@link BatchContext} state. + * @param event Server-side event. + * @return ProtocolViolationException with a formatted message. + */ + static ProtocolViolationException illegalStateTransition(State current, Event event) { + return new ProtocolViolationException("%s arrived in %s state".formatted(event, current)); + } + + /** + * Protocol violated because some tasks from the previous Data message + * are not present in the Acks message. + * + * @param remaining IDs of the tasks that weren't ack'ed. MUST be non-empty. + * @return ProtocolViolationException with a formatted message. + */ + static ProtocolViolationException incompleteAcks(List remaining) { + requireNonNull(remaining, "remaining is null"); + return new ProtocolViolationException("IDs from previous Data message missing in Acks: '%s', ... (%d more)" + .formatted(remaining.get(0), remaining.size() - 1)); + + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java new file mode 100644 index 000000000..a4ba41310 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java @@ -0,0 +1,42 @@ +package io.weaviate.client6.v1.api.collections.batch; + +interface State { + /** + * canSend returns a boolean indicating if sending + * an "insert" message is allowed in this state. + */ + boolean canSend(); + + /** + * canPrepareNext returns a boolean indicating if accepting + * more items into the batch is allowed in this state. 
+ */ + boolean canPrepareNext(); + + /** + * Lifecycle hook that's called after the state is set. + * + *

+ *
+ * This hook MUST be called exactly once. + *
+ * The next state MUST NOT be set until onEnter returns. + * + * @param prev Previous state or null. + */ + void onEnter(State prev); + + /** + * onEvent handles incoming events; these can be generated by the server + * or by a different part of the program -- the {@link State} MUST NOT + * make any assumptions about the event's origin. + * + *

+ * How the event is handled is up to the concrete implementation. + * It may modify {@link BatchContext} internal state, via one of it's + * package-private methods, including transitioning the context to a + * different state via {@link BatchContext#setState(State)}, or start + * a separate process, e.g. the OOM timer. + */ + void onEvent(Event event); +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/StreamFactory.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/StreamFactory.java new file mode 100644 index 000000000..c50eec6ee --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/StreamFactory.java @@ -0,0 +1,13 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import io.grpc.stub.StreamObserver; + +/** + * @param the type of the object sent down the stream. + * @param the type of the object received from the stream. + */ +@FunctionalInterface +interface StreamFactory { + /** Create a new stream for the send-recv observer pair. 
*/ + StreamObserver createStream(StreamObserver recv); +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java new file mode 100644 index 000000000..aad9c8a71 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java @@ -0,0 +1,176 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import static java.util.Objects.requireNonNull; + +import java.time.Instant; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +import javax.annotation.concurrent.ThreadSafe; + +import com.google.protobuf.GeneratedMessage; +import com.google.protobuf.GeneratedMessageV3; + +import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.client6.v1.api.collections.data.BatchReference; + +@ThreadSafe +@SuppressWarnings("deprecation") // protoc uses GeneratedMessageV3 +public final class TaskHandle { + static final TaskHandle POISON = new TaskHandle(); + + /** + * Input value as passed by the user. + * + *

+ * Changes in the {@link #raw}'s underlying value will not be reflected + * in the {@link TaskHandle} (e.g. the serialized version is not updated), + * so users SHOULD treat items passed to and retrieved from {@link TaskHandle} + * as effectively unmodifiable. + */ + private final Data data; + + /** Flag indicating the task has been ack'ed. */ + private final CompletableFuture acked = new CompletableFuture<>(); + + public final record Result(Optional error) { + public Result { + requireNonNull(error, "error is null"); + } + } + + /** + * Task result completes when the client receives {@link Event.Results} + * containing this handle's {@link #id}. + */ + private final CompletableFuture result = new CompletableFuture<>(); + + /** The number of times this task has been retried. */ + private final int retries; + + /** Task creation timestamp. */ + private final Instant createdAt = Instant.now(); + + private TaskHandle(Data data, int retries) { + this.data = requireNonNull(data, "data is null"); + + assert retries >= 0 : "negative retries"; + this.retries = retries; + } + + /** Constructor for {@link WeaviateObject}. */ + TaskHandle(WeaviateObject object, GeneratedMessage.ExtendableMessage data) { + this(new Data(object, object.uuid(), data, Data.Type.OBJECT), 0); + } + + /** Constructor for {@link BatchReference}. */ + TaskHandle(BatchReference reference, GeneratedMessage.ExtendableMessage data) { + this(new Data(reference, reference.target().beacon(), data, Data.Type.REFERENCE), 0); + } + + /** + * Poison pill constructor. + * + *

+ * A handle created with this constructor should not be + * used for anything other than direct comparison using {@code ==} operator; + * calling any method on a poison pill is likely to result in a + * {@link NullPointerException} being thrown. + */ + private TaskHandle() { + this.data = null; + this.retries = 0; + } + + /** + * Creates a new task containing the same data as this task and {@link retries} + * counter incremented by 1. The {@link acked} and {@link result} futures + * are not copied to the returned task. + * + * @return Task handle. + */ + TaskHandle retry() { + return new TaskHandle(data, retries + 1); + } + + String id() { + return data.id(); + } + + Data data() { + return data; + } + + /** Set the {@link #acked} flag. */ + void setAcked() { + acked.complete(null); + } + + /** + * Mark the task successful. This status cannot be changed, so calling + * {@link #setError} afterwards will have no effect. + */ + void setSuccess() { + setResult(new Result(Optional.empty())); + } + + /** + * Mark the task failed. This status cannot be changed, so calling + * {@link #setSuccess} afterwards will have no effect. + * + * @param error Error message. Null values are tolerated, but are only expected + * to occur due to a server's mistake. + * Do not use {@code setError(null)} if the server reports success + * status for the task; prefer {@link #setSuccess} in that case. + */ + void setError(String error) { + setResult(new Result(Optional.ofNullable(error))); + } + + /** + * Set result for this task. + * + * @throws IllegalStateException if the task has not been ack'ed. + */ + private void setResult(Result result) { + if (!acked.isDone()) { + // TODO(dyma): can this happen due to us? + throw new IllegalStateException("Result can only be set for an ack'ed task"); + } + this.result.complete(result); + } + + /** + * Check if the task has been accepted. + * + * @return A future which completes when the server has accepted the task. 
 */ + public CompletableFuture isAcked() { + return acked; + } + + /** + * Retrieve the result for this task. + * + * @return A future which completes when the server + * has reported the result for this task. + */ + public CompletableFuture result() { + return result; + } + + /** + * Number of times this task has been retried. Since {@link TaskHandle} is + * immutable, this value does not change, but retrying a task via + * {@link BatchContext#retry} is reflected in the returned handle's + * {@link #timesRetried}. + */ + public int timesRetried() { + return retries; + } + + @Override + public String toString() { + return "TaskHandle[id=%s, retries=%d, createdAt=%s]".formatted(id(), retries, createdAt); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java new file mode 100644 index 000000000..2e6344ada --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java @@ -0,0 +1,147 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import static java.util.Objects.requireNonNull; + +import java.util.List; + +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import io.grpc.stub.StreamObserver; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamReply; + +/** + * TranslatingStreamFactory is an adaptor for the + * {@link WeaviateGrpc.WeaviateStub#batchStream} factory. The returned stream + * translates client-side messages into protobuf requests and server-side + * replies into events. 
+ * + * @see Message + * @see Event + */ +class TranslatingStreamFactory implements StreamFactory { + private final StreamFactory protoFactory; + + TranslatingStreamFactory( + StreamFactory protoFactory) { + this.protoFactory = requireNonNull(protoFactory, "protoFactory is null"); + } + + @Override + public StreamObserver createStream(StreamObserver recv) { + return new Messenger(protoFactory.createStream(new Eventer(recv))); + } + + /** + * DelegatingStreamObserver delegates {@link #onCompleted} and {@link #onError} + * to another observer and translates the messages in {@link #onNext}. + * + * @param the type of the incoming message. + * @param the type of the message handed to the delegate. + */ + private abstract class DelegatingStreamObserver implements StreamObserver { + protected final StreamObserver delegate; + + protected DelegatingStreamObserver(StreamObserver delegate) { + this.delegate = delegate; + } + + @Override + public void onCompleted() { + delegate.onCompleted(); + } + + @Override + public void onError(Throwable t) { + delegate.onError(t); + } + } + + /** + * Messenger translates client's messages into batch stream requests. + * + * @see Message + */ + private final class Messenger extends DelegatingStreamObserver { + private Messenger(StreamObserver delegate) { + super(delegate); + } + + @Override + public void onNext(Message message) { + WeaviateProtoBatch.BatchStreamRequest.Builder builder = WeaviateProtoBatch.BatchStreamRequest.newBuilder(); + message.appendTo(builder); + delegate.onNext(builder.build()); + } + } + + /** + * Eventer translates server replies into events. 
+ * + * @see Event + */ + private final class Eventer extends DelegatingStreamObserver { + private Eventer(StreamObserver delegate) { + super(delegate); + } + + @Override + public void onNext(BatchStreamReply reply) { + Event event = null; + switch (reply.getMessageCase()) { + case STARTED: + event = Event.STARTED; + break; + case SHUTTING_DOWN: + event = Event.SHUTTING_DOWN; + break; + case SHUTDOWN: + event = Event.EOF; + break; + case OUT_OF_MEMORY: + // TODO(dyma): read this value from the message + event = new Event.Oom(300); + break; + case BACKOFF: + event = new Event.Backoff(reply.getBackoff().getBatchSize()); + break; + case ACKS: + Stream uuids = reply.getAcks().getUuidsList().stream(); + Stream beacons = reply.getAcks().getBeaconsList().stream(); + event = new Event.Acks(Stream.concat(uuids, beacons).toList()); + break; + case RESULTS: + List successful = reply.getResults().getSuccessesList().stream() + .map(detail -> { + if (detail.hasUuid()) { + return detail.getUuid(); + } else if (detail.hasBeacon()) { + return detail.getBeacon(); + } + throw new IllegalArgumentException("Result has neither UUID nor a beacon"); + }) + .toList(); + + Map errors = reply.getResults().getErrorsList().stream() + .map(detail -> { + String error = requireNonNull(detail.getError(), "error is null"); + if (detail.hasUuid()) { + return Map.entry(detail.getUuid(), error); + } else if (detail.hasBeacon()) { + return Map.entry(detail.getBeacon(), error); + } + throw new IllegalArgumentException("Result has neither UUID nor a beacon"); + }) + .collect(Collectors.toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue)); + event = new Event.Results(successful, errors); + break; + case MESSAGE_NOT_SET: + throw new ProtocolViolationException("Message not set"); + } + + delegate.onNext(event); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/WeaviateBatchClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/WeaviateBatchClient.java 
new file mode 100644 index 000000000..453862f5e --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/WeaviateBatchClient.java @@ -0,0 +1,59 @@ +package io.weaviate.client6.v1.api.collections.batch; + +import static java.util.Objects.requireNonNull; + +import java.util.OptionalInt; + +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.internal.TransportOptions; +import io.weaviate.client6.v1.internal.grpc.GrpcTransport; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; + +public class WeaviateBatchClient { + private final CollectionHandleDefaults defaults; + private final CollectionDescriptor collectionDescriptor; + private final GrpcTransport grpcTransport; + + public WeaviateBatchClient( + GrpcTransport grpcTransport, + CollectionDescriptor collectionDescriptor, + CollectionHandleDefaults defaults) { + this.defaults = requireNonNull(defaults, "defaults is null"); + this.collectionDescriptor = requireNonNull(collectionDescriptor, "collectionDescriptor is null"); + this.grpcTransport = requireNonNull(grpcTransport, "grpcTransport is null"); + } + + /** Copy constructor with new defaults. 
*/ + public WeaviateBatchClient(WeaviateBatchClient c, CollectionHandleDefaults defaults) { + this.defaults = requireNonNull(defaults, "defaults is null"); + this.collectionDescriptor = c.collectionDescriptor; + this.grpcTransport = c.grpcTransport; + } + + public BatchContext start() { + OptionalInt maxSizeBytes = grpcTransport.maxMessageSizeBytes(); + if (maxSizeBytes.isEmpty()) { + throw new IllegalStateException("Server must have grpcMaxMessageSize configured to use server-side batching"); + } + + StreamFactory streamFactory = new TranslatingStreamFactory(grpcTransport::createStream); + BatchContext context = new BatchContext<>( + streamFactory, + maxSizeBytes.getAsInt(), + collectionDescriptor, + defaults); + + if (isWeaviateCloudOnGoogleCloud(grpcTransport.host())) { + context.scheduleReconnect(GCP_RECONNECT_INTERVAL_SECONDS); + } + + context.start(); + return context; + } + + private static final int GCP_RECONNECT_INTERVAL_SECONDS = 160; + + private static boolean isWeaviateCloudOnGoogleCloud(String host) { + return TransportOptions.isWeaviateDomain(host) && TransportOptions.isGoogleCloudDomain(host); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/config/GetShardsRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/config/GetShardsRequest.java index 5237fed26..c27c13aa7 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/config/GetShardsRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/config/GetShardsRequest.java @@ -21,8 +21,8 @@ public static final Endpoint> endpoint( return SimpleEndpoint.noBody( request -> "GET", request -> "/schema/" + collection.collectionName() + "/shards", - request -> defaults.tenant() != null - ? Map.of("tenant", defaults.tenant()) + request -> defaults.tenant().isPresent() + ? 
Map.of("tenant", defaults.tenant().get()) : Collections.emptyMap(), (statusCode, response) -> (List) JSON.deserialize(response, TypeToken.getParameterized( List.class, Shard.class))); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/BatchReference.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/BatchReference.java index a83652be5..503ec59d1 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/BatchReference.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/BatchReference.java @@ -1,5 +1,7 @@ package io.weaviate.client6.v1.api.collections.data; +import static java.util.Objects.requireNonNull; + import java.io.IOException; import java.util.Arrays; @@ -9,7 +11,14 @@ import io.weaviate.client6.v1.api.collections.WeaviateObject; -public record BatchReference(String fromCollection, String fromProperty, String fromUuid, ObjectReference reference) { +public record BatchReference(String fromCollection, String fromProperty, String fromUuid, ObjectReference target) { + + public BatchReference { + requireNonNull(fromCollection, "fromCollection is null"); + requireNonNull(fromProperty, "fromProperty is null"); + requireNonNull(fromUuid, "fromUuid is null"); + requireNonNull(target, "target is null"); + } public static BatchReference[] objects(WeaviateObject fromObject, String fromProperty, WeaviateObject... 
toObjects) { @@ -39,7 +48,7 @@ public void write(JsonWriter out, BatchReference value) throws IOException { out.value(ObjectReference.toBeacon(value.fromCollection, value.fromProperty, value.fromUuid)); out.name("to"); - out.value(ObjectReference.toBeacon(value.reference.collection(), value.reference.uuid())); + out.value(ObjectReference.toBeacon(value.target.collection(), value.target.uuid())); out.endObject(); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java index 2fff8681a..e60133957 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java @@ -29,11 +29,9 @@ public static Rpc Rpc, WeaviateProtoBat request -> { var message = WeaviateProtoBatch.BatchObjectsRequest.newBuilder(); - var batch = request.objects.stream().map(obj -> { - var batchObject = WeaviateProtoBatch.BatchObject.newBuilder(); - buildObject(batchObject, obj, collection, defaults); - return batchObject.build(); - }).toList(); - + var batch = request.objects.stream() + .map(obj -> buildObject(obj, collection, defaults)) + .toList(); message.addAllObjects(batch); - if (defaults.consistencyLevel() != null) { - defaults.consistencyLevel().appendTo(message); + if (defaults.consistencyLevel().isPresent()) { + defaults.consistencyLevel().get().appendTo(message); } + var m = message.build(); + m.getSerializedSize(); return message.build(); }, response -> { @@ -92,10 +94,11 @@ public static Rpc, WeaviateProtoBat () -> WeaviateFutureStub::batchObjects); } - public static void buildObject(WeaviateProtoBatch.BatchObject.Builder object, + public static WeaviateProtoBatch.BatchObject buildObject( WeaviateObject insert, CollectionDescriptor collection, CollectionHandleDefaults defaults) { + var object = WeaviateProtoBatch.BatchObject.newBuilder(); 
object.setCollection(collection.collectionName()); if (insert.uuid() != null) { @@ -121,9 +124,7 @@ public static void buildObject(WeaviateProtoBatch.BatchObject.Builder object }).toList(); object.addAllVectors(vectors); } - if (defaults.tenant() != null) { - object.setTenant(defaults.tenant()); - } + defaults.tenant().ifPresent(object::setTenant); var singleRef = new ArrayList(); var multiRef = new ArrayList(); @@ -158,6 +159,7 @@ public static void buildObject(WeaviateProtoBatch.BatchObject.Builder object properties.setNonRefProperties(nonRef); } object.setProperties(properties); + return object.build(); } @SuppressWarnings("unchecked") @@ -330,4 +332,20 @@ private static com.google.protobuf.Struct marshalStruct(Map prop }); return struct.build(); } + + public static WeaviateProtoBatch.BatchReference buildReference(BatchReference reference, Optional tenant) { + requireNonNull(reference, "reference is null"); + WeaviateProtoBatch.BatchReference.Builder builder = WeaviateProtoBatch.BatchReference.newBuilder(); + builder + .setName(reference.fromProperty()) + .setFromCollection(reference.fromCollection()) + .setFromUuid(reference.fromUuid()) + .setToUuid(reference.target().uuid()); + + if (reference.target().collection() != null) { + builder.setToCollection(reference.target().collection()); + } + tenant.ifPresent(t -> builder.setTenant(t)); + return builder.build(); + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java index 8588eb760..3bfcdf52d 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java @@ -27,14 +27,14 @@ public static final Endpoint, Wea return new SimpleEndpoint<>( request -> "POST", request -> "/objects/", - request -> defaults.consistencyLevel() != null - ? 
Map.of("consistency_level", defaults.consistencyLevel()) + request -> defaults.consistencyLevel().isPresent() + ? Map.of("consistency_level", defaults.consistencyLevel().get()) : Collections.emptyMap(), request -> JSON.serialize( new WeaviateObject<>( request.object.uuid(), collection.collectionName(), - defaults.tenant(), + defaults.tenant().get(), request.object.properties(), request.object.vectors(), request.object.createdAt(), diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/ObjectReference.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/ObjectReference.java index bb6a0f27e..822c5f54d 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/ObjectReference.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/ObjectReference.java @@ -67,6 +67,10 @@ public static ObjectReference[] collection(String collection, String... uuids) { .toArray(ObjectReference[]::new); } + public String beacon() { + return toBeacon(collection, uuid); + } + public static String toBeacon(String collection, String uuid) { return toBeacon(collection, null, uuid); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddManyRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddManyRequest.java index 284688daa..cfb2a4667 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddManyRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddManyRequest.java @@ -4,6 +4,7 @@ import java.util.List; import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch; import io.weaviate.client6.v1.internal.json.JSON; import io.weaviate.client6.v1.internal.rest.Endpoint; import io.weaviate.client6.v1.internal.rest.SimpleEndpoint; @@ -32,4 +33,7 @@ public static final Endpoint }); } + public static WeaviateProtoBatch.BatchReference 
buildReference(ObjectReference reference) { + return null; + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReplaceObjectRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReplaceObjectRequest.java index 13a1afacb..45419afff 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReplaceObjectRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReplaceObjectRequest.java @@ -26,14 +26,14 @@ static final Endpoint, Void> end return SimpleEndpoint.sideEffect( request -> "PUT", request -> "/objects/" + collection.collectionName() + "/" + request.object.uuid(), - request -> defaults.consistencyLevel() != null - ? Map.of("consistency_level", defaults.consistencyLevel()) + request -> defaults.consistencyLevel().isPresent() + ? Map.of("consistency_level", defaults.consistencyLevel().get()) : Collections.emptyMap(), request -> JSON.serialize( new WeaviateObject<>( request.object.uuid(), collection.collectionName(), - defaults.tenant(), + defaults.tenant().get(), request.object.properties(), request.object.vectors(), request.object.createdAt(), diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/UpdateObjectRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/UpdateObjectRequest.java index 6157a1cc8..65a1a66f1 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/UpdateObjectRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/UpdateObjectRequest.java @@ -26,14 +26,14 @@ static final Endpoint, Void> endp return SimpleEndpoint.sideEffect( request -> "PATCH", request -> "/objects/" + collection.collectionName() + "/" + request.object.uuid(), - request -> defaults.consistencyLevel() != null - ? Map.of("consistency_level", defaults.consistencyLevel()) + request -> defaults.consistencyLevel().isPresent() + ? 
Map.of("consistency_level", defaults.consistencyLevel().get()) : Collections.emptyMap(), request -> JSON.serialize( new WeaviateObject<>( request.object.uuid(), collection.collectionName(), - defaults.tenant(), + defaults.tenant().get(), request.object.properties(), request.object.vectors(), request.object.createdAt(), diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java index c7497ec64..19613dff2 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java @@ -35,7 +35,7 @@ public WeaviateDataClient( this.defaults = defaults; } - /** Copy constructor that updates the {@link #query} to use new defaults. */ + /** Copy constructor with new defaults. */ public WeaviateDataClient(WeaviateDataClient c, CollectionHandleDefaults defaults) { this.restTransport = c.restTransport; this.grpcTransport = c.grpcTransport; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/ConsistencyLevel.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/ConsistencyLevel.java index 326609d2e..fc688cd41 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/ConsistencyLevel.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/ConsistencyLevel.java @@ -43,6 +43,10 @@ public final void appendTo(WeaviateProtoBatch.BatchObjectsRequest.Builder req) { req.setConsistencyLevel(consistencyLevel); } + public final void appendTo(WeaviateProtoBatch.BatchStreamRequest.Start.Builder req) { + req.setConsistencyLevel(consistencyLevel); + } + @Override public String toString() { return queryParameter; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java index 625dde30d..8bb0db91b 100644 
--- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java @@ -37,11 +37,9 @@ public static WeaviateProtoSearchGet.SearchRequest marshal( } request.operator.appendTo(message); - if (defaults.tenant() != null) { - message.setTenant(defaults.tenant()); - } - if (defaults.consistencyLevel() != null) { - defaults.consistencyLevel().appendTo(message); + defaults.tenant().ifPresent(message::setTenant); + if (defaults.consistencyLevel().isPresent()) { + defaults.consistencyLevel().get().appendTo(message); } if (request.groupBy != null) { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java index b1bc7369e..751492591 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java @@ -3,7 +3,6 @@ import java.util.Arrays; import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.UUID; import java.util.stream.Stream; @@ -92,20 +91,9 @@ static WeaviateObject unmarshalWithReferences( (map, ref) -> { var refObjects = ref.getPropertiesList().stream() .map(property -> { - var reference = unmarshalWithReferences( + return (Reference) unmarshalWithReferences( property, property.getMetadata(), CollectionDescriptor.ofMap(property.getTargetCollection())); - return (Reference) new WeaviateObject<>( - reference.uuid(), - reference.collection(), - // TODO(dyma): we can get tenant from CollectionHandle - null, // tenant is not returned in the query - (Map) reference.properties(), - reference.vectors(), - reference.createdAt(), - reference.lastUpdatedAt(), - reference.queryMetadata(), - reference.references()); }) .toList(); diff --git a/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java 
b/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java index 897bb28cd..06c0b6c15 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java +++ b/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java @@ -57,4 +57,20 @@ public H headers() { public TrustManagerFactory trustManagerFactory() { return this.trustManagerFactory; } + + /** + * isWeaviateDomain returns true if the host matches weaviate.io, + * semi.technology, or weaviate.cloud domain. + */ + public static boolean isWeaviateDomain(String host) { + var lower = host.toLowerCase(); + return lower.contains("weaviate.io") || + lower.contains("semi.technology") || + lower.contains("weaviate.cloud"); + } + + public static boolean isGoogleCloudDomain(String host) { + var lower = host.toLowerCase(); + return lower.contains("gcp"); + } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java index d12255d22..385808ffc 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java @@ -1,6 +1,10 @@ package io.weaviate.client6.v1.internal.grpc; +import static java.util.Objects.requireNonNull; + +import java.util.OptionalInt; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import javax.net.ssl.SSLException; @@ -16,57 +20,57 @@ import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext; import io.grpc.stub.AbstractStub; import io.grpc.stub.MetadataUtils; +import io.grpc.stub.StreamObserver; import io.weaviate.client6.v1.api.WeaviateApiException; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub; import 
io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateFutureStub; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamReply; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamRequest; public final class DefaultGrpcTransport implements GrpcTransport { + /** + * ListenableFuture callbacks are executed + * in the same thread they are called from. + */ + private static final Executor FUTURE_CALLBACK_EXECUTOR = Runnable::run; + + private final GrpcChannelOptions transportOptions; private final ManagedChannel channel; private final WeaviateBlockingStub blockingStub; private final WeaviateFutureStub futureStub; - private final GrpcChannelOptions transportOptions; - private TokenCallCredentials callCredentials; public DefaultGrpcTransport(GrpcChannelOptions transportOptions) { - this.transportOptions = transportOptions; - this.channel = buildChannel(transportOptions); - - var blockingStub = WeaviateGrpc.newBlockingStub(channel) - .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers())); - - var futureStub = WeaviateGrpc.newFutureStub(channel) - .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers())); - - if (transportOptions.maxMessageSize() != null) { - var max = transportOptions.maxMessageSize(); - blockingStub = blockingStub.withMaxInboundMessageSize(max).withMaxOutboundMessageSize(max); - futureStub = futureStub.withMaxInboundMessageSize(max).withMaxOutboundMessageSize(max); - } + requireNonNull(transportOptions, "transportOptions is null"); + this.transportOptions = transportOptions; if (transportOptions.tokenProvider() != null) { this.callCredentials = new TokenCallCredentials(transportOptions.tokenProvider()); - blockingStub = blockingStub.withCallCredentials(callCredentials); - futureStub = futureStub.withCallCredentials(callCredentials); } - this.blockingStub = blockingStub; - this.futureStub = futureStub; + this.channel = 
buildChannel(transportOptions); + this.blockingStub = configure(WeaviateGrpc.newBlockingStub(channel)); + this.futureStub = configure(WeaviateGrpc.newFutureStub(channel)); } private > StubT applyTimeout(StubT stub, Rpc rpc) { if (transportOptions.timeout() == null) { return stub; } - var timeout = rpc.isInsert() + int timeout = rpc.isInsert() ? transportOptions.timeout().insertSeconds() : transportOptions.timeout().querySeconds(); return stub.withDeadlineAfter(timeout, TimeUnit.SECONDS); } + @Override + public OptionalInt maxMessageSizeBytes() { + return transportOptions.maxMessageSize(); + } + @Override public ResponseT performRequest(RequestT request, Rpc rpc) { @@ -96,7 +100,9 @@ public CompletableFuture perf * reusing the thread in which the original future is completed. */ private static final CompletableFuture toCompletableFuture(ListenableFuture listenable) { - var completable = new CompletableFuture(); + requireNonNull(listenable, "listenable is null"); + + CompletableFuture completable = new CompletableFuture<>(); Futures.addCallback(listenable, new FutureCallback() { @Override @@ -113,13 +119,14 @@ public void onFailure(Throwable t) { completable.completeExceptionally(t); } - }, Runnable::run); + }, FUTURE_CALLBACK_EXECUTOR); return completable; } private static ManagedChannel buildChannel(GrpcChannelOptions transportOptions) { - var channel = NettyChannelBuilder.forAddress(transportOptions.host(), transportOptions.port()); + requireNonNull(transportOptions, "transportOptions is null"); + NettyChannelBuilder channel = NettyChannelBuilder.forAddress(transportOptions.host(), transportOptions.port()); if (transportOptions.isSecure()) { channel.useTransportSecurity(); } else { @@ -140,10 +147,29 @@ private static ManagedChannel buildChannel(GrpcChannelOptions transportOptions) } channel.intercept(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers())); - return channel.build(); } + @Override + public StreamObserver createStream(StreamObserver 
recv) { + return configure(WeaviateGrpc.newStub(channel)).batchStream(recv); + } + + /** Apply common configuration to a stub. */ + private > S configure(S stub) { + requireNonNull(stub, "stub is null"); + + stub = stub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers())); + if (transportOptions.maxMessageSize().isPresent()) { + int max = transportOptions.maxMessageSize().getAsInt(); + stub = stub.withMaxInboundMessageSize(max).withMaxOutboundMessageSize(max); + } + if (callCredentials != null) { + stub = stub.withCallCredentials(callCredentials); + } + return stub; + } + @Override public void close() throws Exception { channel.shutdown(); @@ -151,4 +177,9 @@ public void close() throws Exception { callCredentials.close(); } } + + @Override + public String host() { + return transportOptions.host(); + } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java index 5e4453d7f..96366cb5f 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java @@ -1,6 +1,7 @@ package io.weaviate.client6.v1.internal.grpc; import java.util.Map; +import java.util.OptionalInt; import javax.net.ssl.TrustManagerFactory; @@ -10,7 +11,7 @@ import io.weaviate.client6.v1.internal.TransportOptions; public class GrpcChannelOptions extends TransportOptions { - private final Integer maxMessageSize; + private final OptionalInt maxMessageSize; public GrpcChannelOptions(String scheme, String host, int port, Map headers, TokenProvider tokenProvider, TrustManagerFactory tmf, Timeout timeout) { @@ -18,17 +19,18 @@ public GrpcChannelOptions(String scheme, String host, int port, Map ResponseT performRequest(RequestT request, - Rpc rpc); + ResponseT performRequest(RequestT request, + Rpc rpc); + + CompletableFuture performRequestAsync(RequestT request, + 
Rpc rpc); + + /** + * Create stream for batch insertion. + * + * @apiNote Batch insertion is presently the only operation performed over a + * StreamStream connection, which is why we do not parametrize this + * method. + */ + StreamObserver createStream( + StreamObserver recv); + + String host(); - CompletableFuture performRequestAsync(RequestT request, - Rpc rpc); + /** + * Maximum inbound/outbound message size supported by the underlying channel. + */ + OptionalInt maxMessageSizeBytes(); } diff --git a/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaultsTest.java b/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaultsTest.java index 4b420b3fd..95d7cd126 100644 --- a/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaultsTest.java +++ b/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaultsTest.java @@ -24,7 +24,7 @@ public class CollectionHandleDefaultsTest { /** All defaults are {@code null} if none were set. 
*/ @Test public void test_defaults() { - Assertions.assertThat(HANDLE_NONE.consistencyLevel()).as("default ConsistencyLevel").isNull(); + Assertions.assertThat(HANDLE_NONE.consistencyLevel()).as("default ConsistencyLevel").isEmpty(); Assertions.assertThat(HANDLE_NONE.tenant()).as("default tenant").isNull(); } @@ -35,8 +35,8 @@ public void test_defaults() { @Test public void test_withConsistencyLevel() { var handle = HANDLE_NONE.withConsistencyLevel(ConsistencyLevel.QUORUM); - Assertions.assertThat(handle.consistencyLevel()).isEqualTo(ConsistencyLevel.QUORUM); - Assertions.assertThat(HANDLE_NONE.consistencyLevel()).isNull(); + Assertions.assertThat(handle.consistencyLevel()).get().isEqualTo(ConsistencyLevel.QUORUM); + Assertions.assertThat(HANDLE_NONE.consistencyLevel()).isEmpty(); } /** @@ -46,8 +46,8 @@ public void test_withConsistencyLevel() { @Test public void test_withConsistencyLevel_async() { var handle = HANDLE_NONE_ASYNC.withConsistencyLevel(ConsistencyLevel.QUORUM); - Assertions.assertThat(handle.consistencyLevel()).isEqualTo(ConsistencyLevel.QUORUM); - Assertions.assertThat(HANDLE_NONE_ASYNC.consistencyLevel()).isNull(); + Assertions.assertThat(handle.consistencyLevel()).get().isEqualTo(ConsistencyLevel.QUORUM); + Assertions.assertThat(HANDLE_NONE_ASYNC.consistencyLevel()).isEmpty(); } /** @@ -58,7 +58,7 @@ public void test_withConsistencyLevel_async() { public void test_withTenant() { var handle = HANDLE_NONE.withTenant("john_doe"); Assertions.assertThat(handle.tenant()).isEqualTo("john_doe"); - Assertions.assertThat(HANDLE_NONE.consistencyLevel()).isNull(); + Assertions.assertThat(HANDLE_NONE.consistencyLevel()).isEmpty(); } /** @@ -69,6 +69,6 @@ public void test_withTenant() { public void test_withTenant_async() { var handle = HANDLE_NONE_ASYNC.withTenant("john_doe"); Assertions.assertThat(handle.tenant()).isEqualTo("john_doe"); - Assertions.assertThat(HANDLE_NONE_ASYNC.consistencyLevel()).isNull(); + 
Assertions.assertThat(HANDLE_NONE_ASYNC.consistencyLevel()).isEmpty(); } } diff --git a/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleTest.java b/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleTest.java index 9cf1e99d9..50286503a 100644 --- a/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleTest.java +++ b/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleTest.java @@ -131,18 +131,18 @@ public void test_collectionHandleDefaults_rest(String __, rest.assertNext((method, requestUrl, body, query) -> { switch (clLoc) { case QUERY: - Assertions.assertThat(query).containsEntry("consistency_level", defaults.consistencyLevel()); + Assertions.assertThat(query).containsEntry("consistency_level", defaults.consistencyLevel().get()); break; case BODY: - assertJsonHasValue(body, "consistency_level", defaults.consistencyLevel()); + assertJsonHasValue(body, "consistency_level", defaults.consistencyLevel().get()); } switch (tenantLoc) { case QUERY: - Assertions.assertThat(query).containsEntry("tenant", defaults.tenant()); + Assertions.assertThat(query).containsEntry("tenant", defaults.tenant().get()); break; case BODY: - assertJsonHasValue(body, "tenant", defaults.tenant()); + assertJsonHasValue(body, "tenant", defaults.tenant().get()); } }); } @@ -219,7 +219,7 @@ public void test_defaultTenant_getShards() throws IOException { // Assert rest.assertNext((method, requestUrl, body, query) -> { - Assertions.assertThat(query).containsEntry("tenant", defaults.tenant()); + Assertions.assertThat(query).containsEntry("tenant", defaults.tenant().get()); }); } diff --git a/src/test/java/io/weaviate/testutil/transport/MockGrpcTransport.java b/src/test/java/io/weaviate/testutil/transport/MockGrpcTransport.java index ebea2fea7..98af3d227 100644 --- a/src/test/java/io/weaviate/testutil/transport/MockGrpcTransport.java +++ b/src/test/java/io/weaviate/testutil/transport/MockGrpcTransport.java @@ -3,16 +3,21 @@ import 
java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.OptionalInt; import java.util.concurrent.CompletableFuture; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.MessageOrBuilder; import com.google.protobuf.util.JsonFormat; +import io.grpc.stub.StreamObserver; import io.weaviate.client6.v1.internal.grpc.GrpcTransport; import io.weaviate.client6.v1.internal.grpc.Rpc; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamReply; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamRequest; public class MockGrpcTransport implements GrpcTransport { + private final String host = "example.com"; @FunctionalInterface public interface AssertFunction { @@ -57,4 +62,21 @@ public CompletableFuture perf @Override public void close() throws IOException { } + + @Override + public StreamObserver createStream(StreamObserver recv) { + // TODO(dyma): implement for tests + throw new UnsupportedOperationException("Unimplemented method 'createStream'"); + } + + @Override + public OptionalInt maxMessageSizeBytes() { + // TODO(dyma): implement for tests + throw new UnsupportedOperationException("Unimplemented method 'maxMessageSizeBytes'"); + } + + @Override + public String host() { + return host; + } }