+ * This is determined by the server's {@link Backoff} instruction.
+ */
+ @GuardedBy("this")
+ private int maxSize;
+
+ /**
+ * Maximum size of the serialized message in bytes.
+ * Must be greater than zero.
+ *
+ *
+ * This is determined by the {@link GrpcChannelOptions#maxMessageSize}.
+ */
+ @GuardedBy("this")
+ private final long maxSizeBytes;
+
+ /** Total serialized size of the items in the {@link #buffer}. */
+ @GuardedBy("this")
+ private long sizeBytes;
+
+ /** An in-flight batch is unmodifiable. */
+ @GuardedBy("this")
+ private boolean inFlight = false;
+
+ /**
+ * Pending update to the {@link #maxSize}.
+ *
+ * The value is non-empty when {@link #setMaxSize} is called
+ * while the batch is {@link #inFlight}.
+ */
+ @GuardedBy("this")
+ private OptionalInt pendingMaxSize = OptionalInt.empty();
+
  /**
   * Creates an empty batch.
   *
   * @param maxSize      Maximum number of items per batch; must be positive.
   * @param maxSizeBytes Maximum serialized message size, translated into a
   *                     byte budget via {@link MessageSizeUtil#maxSizeBytes}.
   */
  Batch(int maxSize, int maxSizeBytes) {
    assert maxSize > 0 : "non-positive maxSize";

    this.maxSizeBytes = MessageSizeUtil.maxSizeBytes(maxSizeBytes);
    this.maxSize = maxSize;
    this.buffer = new LinkedHashMap<>(maxSize); // LinkedHashMap preserves insertion order.

    checkInvariants();
  }
+
+ /**
+ * Returns true if batch has reached its capacity, either in terms
+ * of the item count or the batch's estimated size in bytes.
+ */
+ synchronized boolean isFull() {
+ return buffer.size() == maxSize || sizeBytes == maxSizeBytes;
+ }
+
  /**
   * Returns true if the batch's internal buffer is empty.
   * If its primary buffer is empty, its backlog is guaranteed
   * to be empty as well (class invariant: the backlog may only be
   * non-empty while the buffer is full).
   */
  synchronized boolean isEmpty() {
    return buffer.isEmpty(); // sizeBytes == 0 is guaranteed by class invariant.
  }
+
+ /**
+ * Prepare a request to be sent. After calling this method, this batch becomes
+ * "in-flight": an attempt to {@link #add} more items to it will be rejected
+ * with an exception.
+ */
+ synchronized Message prepare() {
+ checkInvariants();
+
+ inFlight = true;
+ return builder -> {
+ buffer.forEach((__, data) -> {
+ data.appendTo(builder);
+ });
+ };
+ }
+
+ /**
+ * Set the new {@link #maxSize} for this buffer.
+ *
+ *
extra = buffer.keySet().stream().toList().listIterator();
+ while (extra.hasPrevious() && buffer.size() > maxSize) {
+ addBacklog(buffer.remove(extra.previous()));
+ }
+ } finally {
+ checkInvariants();
+ }
+ }
+
+ /**
+ * Add a data item to the batch.
+ *
+ *
+ * We want to guarantee that, once a work item has been taken from the queue,
+ * it's going to be eventually executed. Because we cannot know if an item
+ * will overflow the batch before it's removed from the queue, the simplest
+ * and safest way to deal with it is to allow {@link Batch} to put
+ * the overflowing item in the {@link #backlog}. The batch is considered
+ * full after that.
+ *
+ * @throws DataTooBigException If the data exceeds the maximum
+ * possible batch size.
+ * @throws IllegalStateException If called on an "in-flight" batch.
+ * @see #prepare
+ * @see #inFlight
+ * @see #clear
+ */
+ synchronized void add(Data data) throws IllegalStateException, DataTooBigException {
+ requireNonNull(data, "data is null");
+ checkInvariants();
+
+ try {
+ if (inFlight) {
+ throw new IllegalStateException("Batch is in-flight");
+ }
+ long remainingBytes = maxSizeBytes - sizeBytes;
+ if (data.sizeBytes() <= remainingBytes) {
+ addSafe(data);
+ return;
+ }
+ if (isEmpty()) {
+ throw new DataTooBigException(data, maxSizeBytes);
+ }
+ // One of the class's invariants is that the backlog must not contain
+ // any items unless the buffer is full. In case this item overflows
+ // the buffer, we put it in the backlog, but pretend the maxSizeBytes
+ // has been reached to satisfy the invariant.
+ // This doubles as a safeguard to ensure the caller cannot add any
+ // more items to the batch before flushing it.
+ addBacklog(data);
+ sizeBytes += remainingBytes;
+ assert isFull() : "batch must be full after an overflow";
+ } finally {
+ checkInvariants();
+ }
+ }
+
  /**
   * Add a data item to the batch.
   *
   * This method does not check {@link Data#sizeBytes}, so the caller
   * must ensure that this item will not overflow the batch.
   */
  private synchronized void addSafe(Data data) {
    // NOTE(review): put() replaces an existing entry with the same id, but
    // sizeBytes is incremented unconditionally -- confirm callers never pass
    // duplicate ids (the wip-map guard in BatchContext appears to ensure this).
    buffer.put(data.id(), data);
    sizeBytes += data.sizeBytes();
  }
+
  /** Add a data item to the {@link #backlog}, timestamping it with "now". */
  private synchronized void addBacklog(Data data) {
    backlog.add(new BacklogItem(data));
  }
+
+ /**
+ * Clear this batch's internal buffer.
+ *
+ *
+ * Once the buffer is pruned, it is re-populated from the backlog
+ * until the former is full or the latter is exhaused.
+ * If {@link #pendingMaxSize} is not empty, it is applied
+ * before re-populating the buffer.
+ *
+ * @return IDs removed from the buffer.
+ */
+ synchronized Collection clear() {
+ checkInvariants();
+
+ try {
+ inFlight = false;
+
+ Set removed = Set.copyOf(buffer.keySet());
+ buffer.clear();
+ sizeBytes = 0;
+
+ if (pendingMaxSize.isPresent()) {
+ setMaxSize(pendingMaxSize.getAsInt());
+ }
+
+ // Populate internal buffer from the backlog.
+ // We don't need to check the return value of .add(),
+ // as all items in the backlog are guaranteed to not
+ // exceed maxSizeBytes.
+ backlog.stream()
+ .takeWhile(__ -> !isFull())
+ .map(BacklogItem::data)
+ .forEach(this::addSafe);
+
+ return removed;
+ } finally {
+ checkInvariants();
+ }
+ }
+
+ private static record BacklogItem(Data data, Instant createdAt) {
+ public BacklogItem {
+ requireNonNull(data, "data is null");
+ requireNonNull(createdAt, "createdAt is null");
+ }
+
+ /**
+ * This constructor sets {@link #createdAt} automatically.
+ * It is not important that this timestamp is different from
+ * the one in {@link TaskHandle}, as longs as the order is correct.
+ */
+ public BacklogItem(Data data) {
+ this(data, Instant.now());
+ }
+
+ /** Comparator sorts BacklogItems by their creation time. */
+ private static Comparator comparator() {
+ return new Comparator() {
+
+ @Override
+ public int compare(BacklogItem a, BacklogItem b) {
+ if (a.equals(b)) {
+ return 0;
+ }
+
+ int cmpInstant = a.createdAt.compareTo(b.createdAt);
+ boolean sameInstant = cmpInstant == 0;
+ if (sameInstant) {
+ // We cannot return 0 for two items with different
+ // contents, as it may result in data loss.
+ // If they were somehow created in the same instant,
+ // let them be sorted lexicographically.
+ return a.data.id().compareTo(b.data.id());
+ }
+ return cmpInstant;
+ }
+ };
+ }
+ }
+
  /**
   * Asserts the invariants of this class.
   *
   * <p>The {@code assert} statements are active only when the JVM runs with
   * assertions enabled ({@code -ea}); the {@code requireNonNull} check on
   * {@link #pendingMaxSize} always runs.
   */
  private synchronized void checkInvariants() {
    assert maxSize > 0 : "non-positive maxSize";
    assert maxSizeBytes > 0 : "non-positive maxSizeBytes";
    assert sizeBytes >= 0 : "negative sizeBytes";
    assert buffer.size() <= maxSize : "buffer exceeds maxSize";
    assert sizeBytes <= maxSizeBytes : "message exceeds maxSizeBytes";
    // The backlog only absorbs overflow, so it must be empty while the
    // buffer still has capacity.
    if (!isFull()) {
      assert backlog.isEmpty() : "backlog not empty when buffer not full";
    }
    if (buffer.isEmpty()) {
      assert sizeBytes == 0 : "sizeBytes must be 0 when buffer is empty";
    }

    requireNonNull(pendingMaxSize, "pendingMaxSize is null");
    // A pending resize may only exist while a batch is in-flight;
    // it is applied (and emptied) in clear().
    if (!inFlight) {
      assert pendingMaxSize.isEmpty() : "open batch has pending maxSize";
    }
  }
+}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java
new file mode 100644
index 000000000..d52745976
--- /dev/null
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/BatchContext.java
@@ -0,0 +1,891 @@
+package io.weaviate.client6.v1.api.collections.batch;
+
+import static java.util.Objects.requireNonNull;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.EnumSet;
+import java.util.List;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
+import java.util.function.BiConsumer;
+
+import javax.annotation.concurrent.GuardedBy;
+
+import io.grpc.Status;
+import io.grpc.stub.StreamObserver;
+import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults;
+import io.weaviate.client6.v1.api.collections.WeaviateObject;
+import io.weaviate.client6.v1.api.collections.batch.Event.ClientError;
+import io.weaviate.client6.v1.api.collections.batch.Event.StreamHangup;
+import io.weaviate.client6.v1.api.collections.data.BatchReference;
+import io.weaviate.client6.v1.api.collections.data.InsertManyRequest;
+import io.weaviate.client6.v1.internal.orm.CollectionDescriptor;
+
+/**
+ * BatchContext stores the state of an active batch process
+ * and controls its lifecycle.
+ *
+ * State
+ *
+ * Lifecycle
+ *
+ * Cancellation policy
+ *
+ * @param the shape of properties for inserted objects.
+ */
+public final class BatchContext implements Closeable {
+ private final int DEFAULT_BATCH_SIZE = 1_000;
+ private final int DEFAULT_QUEUE_SIZE = 100;
+ private final int MAX_RECONNECT_RETRIES = 5;
+
+ private final CollectionDescriptor collectionDescriptor;
+ private final CollectionHandleDefaults collectionHandleDefaults;
+
+ /**
+ * Internal execution service. Its lifecycle is bound to that of the
+ * BatchContext: it's started when the context is initialized
+ * and shutdown on {@link #close}.
+ *
+ *
+ * In the event of abrupt stream termination ({@link Recv#onError} is called),
+ * the "recv" thread MAY shutdown this service in order to interrupt the "send"
+ * thread; the latter may be blocked on {@link Send#awaitCanSend} or
+ * {@link Send#awaitCanPrepareNext}.
+ */
+ private final ExecutorService sendExec = Executors.newSingleThreadExecutor();
+
+ /**
+ * Scheduled thread pool for delayed tasks.
+ *
+ * @see Oom
+ * @see Reconnecting
+ */
+ private final ScheduledExecutorService scheduledExec = Executors.newScheduledThreadPool(1);
+
+ /** The thread that created the context. */
+ private final Thread parent = Thread.currentThread();
+
+ /** Stream factory creates new streams. */
+ private final StreamFactory streamFactory;
+
+ /**
+ * Queue publishes insert tasks from the main thread to the "sender".
+ * It has a maximum capacity of {@link #DEFAULT_QUEUE_SIZE}.
+ *
+ * Send {@link TaskHandle#POISON} to gracefully shutdown the "sender"
+ * thread. The same queue may be re-used with a different "sender",
+ * e.g. after {@link #reconnect}, but only when the new thread is known
+ * to have started. Otherwise the thread trying to put an item on
+ * the queue will block indefinitely.
+ */
+ private final BlockingQueue queue;
+
+ /**
+ * Work-in-progress items.
+ *
+ * An item is added to the {@link #wip} map after the Sender successfully
+ * adds it to the {@link #batch} and is removed once the server reports
+ * back the result (whether success of failure).
+ */
+ private final ConcurrentMap wip = new ConcurrentHashMap<>();
+
+ /**
+ * Current batch.
+ *
+ *
+ * An item is added to the {@link #batch} after the Sender pulls it
+ * from the queue and remains there until it's Ack'ed.
+ */
+ private final Batch batch;
+
+ /**
+ * State encapsulates state-dependent behavior of the {@link BatchContext}.
+ * Before reading {@link #state}, a thread MUST acquire {@link #lock}.
+ */
+ @GuardedBy("lock")
+ private State state;
+ /** lock synchronizes access to {@link #state}. */
+ private final Lock lock = new ReentrantLock();
+ /** stateChanged notifies threads about a state transition. */
+ private final Condition stateChanged = lock.newCondition();
+
+ /**
+ * Client-side part of the current stream, created on {@link #start}.
+ * Other threads MAY use stream but MUST NOT update this field on their own.
+ */
+ private volatile StreamObserver messages;
+
+ /** Handle for the "send" thread. Use {@link Future#cancel} to interrupt it. */
+ private volatile Future> send;
+
+ /**
+ * Latch reaches zero once both "send" (client side) and "recv" (server side)
+ * parts of the stream have closed. After a {@link reconnect}, the latch is
+ * reset.
+ */
+ private volatile CountDownLatch workers;
+
+ /** closing completes the stream. */
+ private final CompletableFuture closing = new CompletableFuture<>();
+
+ /** Executor for performing the shutdown sequence. */
+ private final ExecutorService shutdownExec = Executors.newSingleThreadExecutor();
+
+ /** Lightweight check to ensure users cannot send on a closed context. */
+ private volatile boolean closed;
+
+ // /** Closing state. */
+ // private volatile Closing closing;
+
+ // /**
+ // * setClosing trasitions BatchContext to {@link Closing} state exactly once.
+ // * Once this method returns, the caller can call {@code closing.await()}.
+ // */
+ // void setClosing(Exception ex) {
+ // if (closing == null) {
+ // synchronized (Closing.class) {
+ // if (closing == null) {
+ // closing = new Closing(ex);
+ // setState(closing);
+ // }
+ // }
+ // }
+ // }
+
+ BatchContext(
+ StreamFactory streamFactory,
+ int maxSizeBytes,
+ CollectionDescriptor collectionDescriptor,
+ CollectionHandleDefaults collectionHandleDefaults) {
+ this.streamFactory = requireNonNull(streamFactory, "streamFactory is null");
+ this.collectionDescriptor = requireNonNull(collectionDescriptor, "collectionDescriptor is null");
+ this.collectionHandleDefaults = requireNonNull(collectionHandleDefaults, "collectionHandleDefaults is null");
+
+ this.queue = new ArrayBlockingQueue<>(DEFAULT_QUEUE_SIZE);
+ this.batch = new Batch(DEFAULT_BATCH_SIZE, maxSizeBytes);
+ }
+
+ /** Add {@link WeaviateObject} to the batch. */
+ public TaskHandle add(WeaviateObject object) throws InterruptedException {
+ TaskHandle handle = new TaskHandle(
+ object,
+ InsertManyRequest.buildObject(object, collectionDescriptor, collectionHandleDefaults));
+ return add(handle);
+ }
+
+ /** Add {@link BatchReference} to the batch. */
+ public TaskHandle add(BatchReference reference) throws InterruptedException {
+ TaskHandle handle = new TaskHandle(
+ reference,
+ InsertManyRequest.buildReference(reference, collectionHandleDefaults.tenant()));
+ return add(handle);
+ }
+
+ void start() {
+ start(AWAIT_STARTED);
+ }
+
+ void start(State nextState) {
+ workers = new CountDownLatch(2);
+
+ messages = streamFactory.createStream(new Recv());
+
+ // Start the stream and await Started message.
+ messages.onNext(Message.start(collectionHandleDefaults.consistencyLevel()));
+ setState(nextState);
+
+ // "send" routine must start after the nextState has been set.
+ send = sendExec.submit(new Send());
+ }
+
+ /**
+ * Reconnect waits for "send" and "recv" streams to exit
+ * and restarts the process with a new stream.
+ *
+ * @param reconnecting Reconnecting instance that called reconnect.
+ */
+ void reconnect(Reconnecting reconnecting) throws InterruptedException, ExecutionException {
+ workers.await();
+ send.get();
+ start(reconnecting);
+ }
+
+ /**
+ * Retry a task.
+ *
+ * BatchContext does not impose any limit on the number of times a task can
+ * be retried -- it is up to the user to implement an appropriate retry policy.
+ *
+ * @see TaskHandle#timesRetried
+ */
+ public TaskHandle retry(TaskHandle taskHandle) throws InterruptedException {
+ return add(taskHandle.retry());
+ }
+
+ /**
+ * Close attempts to drain the queue and send all remaining items.
+ * Calling any of BatchContext's public methods afterwards will
+ * result in an {@link IllegalStateException}.
+ *
+ * @throws IOException Propagates an exception
+ * if one has occurred in the meantime.
+ */
+ @Override
+ public void close() throws IOException {
+ closed = true;
+
+ try {
+ shutdown();
+ } catch (InterruptedException | ExecutionException e) {
+ if (e instanceof InterruptedException ||
+ e.getCause() instanceof InterruptedException) {
+ Thread.currentThread().interrupt();
+ }
+ throw new IOException(e.getCause());
+ } finally {
+ shutdownExecutors();
+ }
+ }
+
+ private void shutdown() throws InterruptedException, ExecutionException {
+ CompletableFuture gracefulShutdown = CompletableFuture.runAsync(() -> {
+ try {
+ // Poison the queue -- this will signal "send" to drain the remaing
+ // items in the batch and in the backlog and exit.
+ //
+ // If shutdownNow has been called previously and the "send" routine
+ // has been interrupted, this may block indefinitely.
+ // However, shutdownNow ensures that `closing` future is resolved.
+ queue.put(TaskHandle.POISON);
+
+ // Wait for the send to exit before closing our end of the stream.
+ send.get();
+ messages.onNext(Message.stop());
+ messages.onCompleted();
+
+ // Wait for both "send" and "recv" to exit.
+ workers.await();
+ closing.complete(null);
+ } catch (Exception e) {
+ closing.completeExceptionally(e);
+ }
+
+ }, shutdownExec);
+
+ // Complete shutdown as soon as one of these futures are completed.
+ // - gracefulShutdown completes if we managed to shutdown normally.
+ // - closing may complete sooner if shutdownNow is called.
+ CompletableFuture.anyOf(closing, gracefulShutdown).get();
+ }
+
+ private void shutdownNow(Exception ex) {
+ // Terminate the "send" routine and wait for it to exit.
+ // Since we're already in the error state we do not care
+ // much if it throws or not.
+ send.cancel(true);
+ try {
+ send.get();
+ } catch (Exception e) {
+ }
+
+ // Now report this error to the server and close the stream.
+ closing.completeExceptionally(ex);
+ messages.onError(Status.INTERNAL.withCause(ex).asRuntimeException());
+
+ // Since shutdownNow is never triggerred by the "main" thread,
+ // it may be blocked on trying to add to the queue. While batch
+ // context is active, we own this thread and may interrupt it.
+ parent.interrupt();
+ }
+
+ private void shutdownExecutors() {
+ BiConsumer> assertEmpty = (name, pending) -> {
+ assert pending.isEmpty() : "'%s' service had %d tasks awaiting execution"
+ .formatted(pending.size(), name);
+ };
+
+ List pending;
+
+ pending = sendExec.shutdownNow();
+ assertEmpty.accept("send", pending);
+
+ pending = scheduledExec.shutdownNow();
+ assertEmpty.accept("oom", pending);
+
+ pending = shutdownExec.shutdownNow();
+ assertEmpty.accept("shutdown", pending);
+ }
+
+ /** Set the new state and notify awaiting threads. */
+ void setState(State nextState) {
+ requireNonNull(nextState, "nextState is null");
+
+ lock.lock();
+ System.out.println("setState " + state + " => " + nextState);
+ try {
+ State prev = state;
+ state = nextState;
+ state.onEnter(prev);
+ stateChanged.signal();
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ /**
+ * onEvent delegates event handling to {@link #state}.
+ *
+ *
+ * Be mindful that most of the time this callback will run in a hot path
+ * on a gRPC thread. {@link State} implementations SHOULD offload any
+ * blocking operations to one of the provided executors.
+ *
+ * @see #scheduledExec
+ */
+ private void onEvent(Event event) {
+ lock.lock();
+ try {
+ System.out.println("onEvent " + event);
+ state.onEvent(event);
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ private TaskHandle add(final TaskHandle taskHandle) throws InterruptedException {
+ if (closed) {
+ throw new IllegalStateException("BatchContext is closed");
+ }
+
+ TaskHandle existing = wip.get(taskHandle.id());
+ if (existing != null) {
+ throw new DuplicateTaskException(taskHandle, existing);
+ }
+
+ queue.put(taskHandle);
+ return taskHandle;
+ }
+
+ private final class Send implements Runnable {
+
+ @Override
+ public void run() {
+ try {
+ trySend();
+ } finally {
+ workers.countDown();
+ }
+ }
+
+ /**
+ * trySend consumes {@link #queue} tasks and sends them in batches until it
+ * encounters a {@link TaskHandle#POISON} or is otherwise interrupted.
+ */
+ private void trySend() {
+ try {
+ awaitCanPrepareNext();
+
+ while (!Thread.currentThread().isInterrupted()) {
+ if (batch.isFull()) {
+ System.out.println("==[send batch]==>");
+ send();
+ }
+
+ TaskHandle task = queue.take();
+
+ if (task == TaskHandle.POISON) {
+ System.out.println("took POISON");
+ drain();
+ return;
+ }
+
+ Data data = task.data();
+ batch.add(data);
+
+ TaskHandle existing = wip.put(task.id(), task);
+ assert existing == null : "duplicate tasks in progress, id=" + existing.id();
+ }
+ } catch (InterruptedException ignored) {
+ Thread.currentThread().interrupt();
+ } catch (Exception e) {
+ onEvent(new Event.ClientError(e));
+ return;
+ }
+ }
+
+ /**
+ * Send the current portion of batch items. After this method returns, the batch
+ * is guaranteed to have space for at least one the next item (not full).
+ */
+ private void send() throws InterruptedException {
+ // Continue flushing until we get the batch to not a "not full" state.
+ // This is to account for the backlog, which might re-fill the batch
+ // after .clear().
+ while (batch.isFull()) {
+ flush();
+ }
+ assert !batch.isFull() : "batch is full after send";
+ }
+
+ /**
+ * Send all remaining items in the batch. After this method returns, the batch
+ * is guaranteed to be empty.
+ */
+ private void drain() throws InterruptedException {
+ // To correctly drain the batch, we flush repeatedly
+ // until the batch becomes empty, as clearing a batch
+ // after an ACK might re-populate it from its internal backlog.
+ while (!batch.isEmpty()) {
+ flush();
+ }
+ assert batch.isEmpty() : "batch not empty after drain";
+ }
+
+ private void flush() throws InterruptedException {
+ awaitCanSend();
+ messages.onNext(batch.prepare());
+ setState(IN_FLIGHT);
+
+ // When we get into OOM / ServerShuttingDown state, then we can be certain that
+ // there isn't any reason to keep waiting for the ACKs. However, we should not
+ // exit without either taking a poison pill from the queue,
+ // or being interrupted, as this risks blocking the producer (main) thread.
+ awaitCanPrepareNext();
+ }
+
+ /** Block until the current state allows {@link State#canSend}. */
+ private void awaitCanSend() throws InterruptedException {
+ lock.lock();
+ try {
+ while (!state.canSend()) {
+ stateChanged.await();
+ }
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ /**
+ * Block until the current state allows {@link State#canPrepareNext}.
+ *
+ *
+ * Depending on the BatchContext lifecycle, the semantics of
+ * "await can prepare next" can be one of "message is ACK'ed"
+ * "the stream has started", or, more generally,
+ * "it is safe to take a next item from the queue and add it to the batch".
+ */
+ private void awaitCanPrepareNext() throws InterruptedException {
+ lock.lock();
+ try {
+ while (!state.canPrepareNext()) {
+ stateChanged.await();
+ }
+ } finally {
+ lock.unlock();
+ }
+ }
+ }
+
+ private final class Recv implements StreamObserver {
+
+ @Override
+ public void onNext(Event event) {
+ onEvent(event);
+ }
+
+ /**
+ * EOF for the server-side stream.
+ * By the time this is called, the client-side of the stream had been closed
+ * and the "send" thread has either exited or is on its way there.
+ */
+ @Override
+ public void onCompleted() {
+ try {
+ onEvent(Event.EOF);
+ } finally {
+ workers.countDown();
+ }
+ }
+
+ /** An exception occurred either on our end or in the channel internals. */
+ @Override
+ public void onError(Throwable t) {
+ try {
+ onEvent(Event.StreamHangup.fromThrowable(t));
+ } finally {
+ workers.countDown();
+ }
+ }
+ }
+
+ final State AWAIT_STARTED = new BaseState("AWAIT_STARTED", BaseState.Action.PREPARE_NEXT) {
+ @Override
+ public void onEvent(Event event) {
+ if (requireNonNull(event, "event is null") == Event.STARTED) {
+ setState(ACTIVE);
+ } else {
+ super.onEvent(event);
+ }
+ }
+ };
+ final State ACTIVE = new BaseState("ACTIVE", BaseState.Action.PREPARE_NEXT, BaseState.Action.SEND);
+ final State IN_FLIGHT = new BaseState("IN_FLIGHT") {
+ @Override
+ public void onEvent(Event event) {
+ requireNonNull(event, "event is null");
+
+ if (event instanceof Event.Acks acks) {
+ Collection removed = batch.clear();
+ if (!acks.acked().containsAll(removed)) {
+ throw ProtocolViolationException.incompleteAcks(List.copyOf(removed));
+ }
+ acks.acked().forEach(id -> {
+ TaskHandle task = wip.get(id);
+ if (task != null) {
+ task.setAcked();
+ }
+ });
+ setState(ACTIVE);
+ } else if (event instanceof Event.Oom oom) {
+ setState(new Oom(oom.delaySeconds()));
+ } else {
+ super.onEvent(event);
+ }
+ }
+ };
+
+ private class BaseState implements State {
+ /** State's display name for logging. */
+ private final String name;
+ /** Actions permitted in this state. */
+ private final EnumSet permitted;
+
+ enum Action {
+ /**
+ * Thy system is allowed to accept new items from the user
+ * and populate the next batch.
+ */
+ PREPARE_NEXT,
+
+ /** The system is allowed to send the next batch once it's ready. */
+ SEND;
+ }
+
+ /**
+ * @param name Display name.
+ * @param permitted Actions permitted in this state.
+ */
+ protected BaseState(String name, Action... permitted) {
+ this.name = name;
+ this.permitted = requireNonNull(permitted, "actions is null").length == 0
+ ? EnumSet.noneOf(Action.class)
+ : EnumSet.copyOf(Arrays.asList(permitted));
+ }
+
+ @Override
+ public void onEnter(State prev) {
+ }
+
+ @Override
+ public boolean canSend() {
+ return permitted.contains(Action.SEND);
+ }
+
+ @Override
+ public boolean canPrepareNext() {
+ return permitted.contains(Action.PREPARE_NEXT);
+ }
+
+ /**
+ * Handle events which may arrive at any moment without violating the protocol.
+ *
+ *
+ * - {@link Event.Results} -- update tasks in {@link #wip} and remove them.
+ *
- {@link Event.Backoff} -- adjust batch size.
+ *
- {@link Event#SHUTTING_DOWN} -- transition into
+ * {@link ServerShuttingDown}.
+ *
- {@link Event.StreamHangup -- transition into {@link Reconnecting} state.
+ *
- {@link Event.ClientError -- shutdown the service immediately.
+ *
+ *
+ * @throws ProtocolViolationException If event cannot be handled in this state.
+ * @see BatchContext#shutdownNow
+ */
+ @Override
+ public void onEvent(Event event) {
+ requireNonNull(event, "event is null");
+
+ if (event instanceof Event.Results results) {
+ onResults(results);
+ } else if (event instanceof Event.Backoff backoff) {
+ onBackoff(backoff);
+ } else if (event == Event.SHUTTING_DOWN) {
+ onShuttingDown();
+ } else if (event instanceof Event.StreamHangup || event == Event.EOF) {
+ onStreamClosed(event);
+ } else if (event instanceof Event.ClientError error) {
+ onClientError(error);
+ } else {
+ throw ProtocolViolationException.illegalStateTransition(this, event);
+ }
+ }
+
+ private final void onResults(Event.Results results) {
+ results.successful().forEach(id -> wip.remove(id).setSuccess());
+ results.errors().forEach((id, error) -> wip.remove(id).setError(error));
+ }
+
+ private final void onBackoff(Event.Backoff backoff) {
+ System.out.print("========== BACKOFF ==============");
+ System.out.print(backoff.maxSize());
+ System.out.print("=================================");
+ batch.setMaxSize(backoff.maxSize());
+ }
+
+ private final void onShuttingDown() {
+ setState(new ServerShuttingDown(this));
+ }
+
+ private final void onStreamClosed(Event event) {
+ if (event instanceof Event.StreamHangup hangup) {
+ hangup.exception().printStackTrace();
+ }
+ if (!send.isDone()) {
+ setState(new Reconnecting(MAX_RECONNECT_RETRIES));
+ }
+ }
+
+ private final void onClientError(Event.ClientError error) {
+ shutdownNow(error.exception());
+ }
+
+ @Override
+ public String toString() {
+ return name;
+ }
+ }
+
+ /**
+ * Oom waits for {@link Event#SHUTTING_DOWN} up to a specified amount of time,
+ * after which it will force stream termination by imitating server shutdown.
+ */
+ private final class Oom extends BaseState {
+ private final long delaySeconds;
+ private ScheduledFuture> shutdown;
+
+ private Oom(long delaySeconds) {
+ super("OOM");
+ this.delaySeconds = delaySeconds;
+ }
+
+ @Override
+ public void onEnter(State prev) {
+ shutdown = scheduledExec.schedule(this::initiateShutdown, delaySeconds, TimeUnit.SECONDS);
+ }
+
+ /** Imitate server shutdown sequence. */
+ private void initiateShutdown() {
+ // We cannot route event handling via normal BatchContext#onEvent, because
+ // it delegates to the current state, which is Oom. If Oom#onEvent were to
+ // receive an Event.SHUTTING_DOWN, it would cancel this execution of this
+ // very sequence. Instead, we delegate to our parent BaseState which normally
+ // handles these events.
+ if (Thread.currentThread().isInterrupted()) {
+ super.onEvent(Event.SHUTTING_DOWN);
+ }
+ if (Thread.currentThread().isInterrupted()) {
+ super.onEvent(Event.EOF);
+ }
+ }
+
+ @Override
+ public void onEvent(Event event) {
+ requireNonNull(event, "event");
+ if (event == Event.SHUTTING_DOWN ||
+ event instanceof StreamHangup ||
+ event instanceof ClientError) {
+ shutdown.cancel(true);
+ try {
+ shutdown.get();
+ } catch (CancellationException ignored) {
+ } catch (InterruptedException ignored) {
+ // Recv is running on a thread from gRPC's internal thread pool,
+ // so, while onEvent allows InterruptedException to stay responsive,
+ // in practice this thread will only be interrupted by the thread pool,
+ // which already knows it's being shut down.
+ } catch (ExecutionException ex) {
+ onEvent(new Event.ClientError(ex));
+ }
+ }
+ super.onEvent(event);
+ }
+ }
+
  /**
   * ServerShuttingDown allows preparing the next batch
   * unless the server's OOM'ed on the previous one.
   * Once set, the state will shutdown {@link BatchContext#sendExec}
   * to instruct the "send" thread to close our part of the stream.
   */
  private final class ServerShuttingDown extends BaseState {
    // Computed once from the previous state: preparing the next batch is
    // only forbidden when the shutdown was triggered by an OOM.
    private final boolean canPrepareNext;

    private ServerShuttingDown(State previous) {
      super("SERVER_SHUTTING_DOWN");
      this.canPrepareNext = requireNonNull(previous, "previous is null").getClass() != Oom.class;
    }

    @Override
    public boolean canPrepareNext() {
      return canPrepareNext;
    }

    @Override
    public boolean canSend() {
      return false;
    }

    @Override
    public void onEnter(State prev) {
      // Interrupt the "send" thread so it stops waiting and exits.
      send.cancel(true);
    }
  }
+
+ /**
+ * Reconnecting state is entered either by the server finishing a shutdown
+ * and closing its end of the stream, or by an unexpected stream hangup.
+ *
+ * @see Recv#onCompleted graceful server shutdown
+ * @see Recv#onError stream hangup
+ */
+ private final class Reconnecting extends BaseState {
+ private final int maxRetries;
+ private int retries = 0;
+
+ private Reconnecting(int maxRetries) {
+ super("RECONNECTING", Action.PREPARE_NEXT);
+ this.maxRetries = maxRetries;
+ }
+
+ @Override
+ public void onEnter(State prev) {
+ // The reconnected state is re-set every time the stream restarts.
+ // This ensures that onEnter hook is only called the first
+ // time we enter Reconnecting state.
+ if (prev == this) {
+ return;
+ }
+
+ send.cancel(true);
+
+ if (prev.getClass() != ServerShuttingDown.class) {
+ // This is NOT an orderly shutdown, we're reconnecting after a stream hangup.
+ // Assume all WIP items have been lost and re-submit everything.
+ // All items in the batch are contained in WIP, so it is safe to discard the
+ // batch entirely and re-populate from WIP.
+ while (!batch.isEmpty()) {
+ batch.clear();
+ }
+
+ // Unlike during normal operation, we will not stop when batch.isFull().
+ // Batch#add guarantees that data will not be discarded in the event of
+ // an overflow -- all extra items are placed into the backlog, which is
+ // unbounded.
+ wip.values().forEach(task -> batch.add(task.data()));
+ }
+
+ reconnectNow();
+ }
+
+ @Override
+ public void onEvent(Event event) {
+ assert retries <= maxRetries : "maxRetries exceeded";
+
+ if (event == Event.STARTED) {
+ setState(ACTIVE);
+ } else if (event instanceof Event.StreamHangup) {
+ if (retries == maxRetries) {
+ onEvent(new Event.ClientError(new IOException("Server unavailable")));
+ } else {
+ reconnectAfter(1 * 2 ^ retries);
+ }
+ }
+
+ assert retries <= maxRetries : "maxRetries exceeded";
+ }
+
+ /** Reconnect with no delay. */
+ private void reconnectNow() {
+ reconnectAfter(0);
+ }
+
+ /**
+ * Schedule a task to {@link #reconnect} after a delay.
+ *
+ * @param delaySeconds Delay in seconds.
+ *
+ * @apiNote The task is scheduled on {@link #scheduledExec} even if
+ * {@code delaySeconds == 0} to avoid blocking gRPC worker thread,
+ * where the {@link BatchContext#onEvent} callback runs.
+ */
+ private void reconnectAfter(long delaySeconds) {
+ retries++;
+
+ scheduledExec.schedule(() -> {
+ try {
+ reconnect(this);
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ } catch (ExecutionException e) {
+ onEvent(new Event.ClientError(e));
+ }
+ }, delaySeconds, TimeUnit.SECONDS);
+ }
+ }
+
+ // --------------------------------------------------------------------------
+
+ private final ScheduledExecutorService reconnectExec = Executors.newScheduledThreadPool(1);
+
+ void scheduleReconnect(int reconnectIntervalSeconds) {
+ reconnectExec.scheduleWithFixedDelay(() -> {
+ if (Thread.currentThread().isInterrupted()) {
+ onEvent(Event.SHUTTING_DOWN);
+ }
+ if (Thread.currentThread().isInterrupted()) {
+ onEvent(Event.EOF);
+ }
+
+ // We want to count down from the moment we re-opened the stream,
+ // not from the moment we initialited the sequence.
+ lock.lock();
+ try {
+ while (state != ACTIVE) {
+ stateChanged.await();
+ }
+ } catch (InterruptedException ignored) {
+ // Let the process exit normally.
+ } finally {
+ lock.unlock();
+ }
+ }, reconnectIntervalSeconds, reconnectIntervalSeconds, TimeUnit.SECONDS);
+ }
+}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Data.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Data.java
new file mode 100644
index 000000000..dba9e79ad
--- /dev/null
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Data.java
@@ -0,0 +1,94 @@
+package io.weaviate.client6.v1.api.collections.batch;
+
+import static java.util.Objects.requireNonNull;
+
+import javax.annotation.concurrent.Immutable;
+
+import com.google.protobuf.GeneratedMessage;
+import com.google.protobuf.GeneratedMessageV3;
+
+import io.weaviate.client6.v1.api.collections.WeaviateObject;
+import io.weaviate.client6.v1.api.collections.data.ObjectReference;
+import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch;
+
+@Immutable
+@SuppressWarnings("deprecation") // protoc uses GeneratedMessageV3
+class Data implements Message {
+
+ /**
+ * Raw input value, exactly as passed by the user.
+ * Within this class it is only read by {@link #toString}.
+ */
+ private final Object raw;
+
+ /**
+ * Task ID. Depending on the underlying object, this will either be
+ * {@link WeaviateObject#uuid} or {@link ObjectReference#beacon}.
+ *
+ * Since UUIDs and beacons cannot clash, ID does not encode any
+ * information about the underlying data type.
+ */
+ private final String id;
+
+ /**
+ * Serialized representation of the {@link #raw}. This value is immutable
+ * for the entire lifecycle of the handle.
+ */
+ private final GeneratedMessage.ExtendableMessage message;
+
+ /** Estimated size of the {@link #message} when serialized. */
+ private final int sizeBytes;
+
+ /** Kind of payload, mapped to its protobuf field number in the Data container. */
+ enum Type {
+ OBJECT(WeaviateProtoBatch.BatchStreamRequest.Data.OBJECTS_FIELD_NUMBER),
+ REFERENCE(WeaviateProtoBatch.BatchStreamRequest.Data.REFERENCES_FIELD_NUMBER);
+
+ private final int fieldNumber;
+
+ private Type(int fieldNumber) {
+ this.fieldNumber = fieldNumber;
+ }
+
+ /** Protobuf field number of this payload type inside the Data container. */
+ public int fieldNumber() {
+ return fieldNumber;
+ }
+ }
+
+ private Data(Object raw, String id, GeneratedMessage.ExtendableMessage message, int sizeBytes) {
+ this.raw = requireNonNull(raw, "raw is null");
+ this.id = requireNonNull(id, "id is null");
+ this.message = requireNonNull(message, "message is null");
+
+ assert sizeBytes >= 0;
+ this.sizeBytes = sizeBytes;
+ }
+
+ /** Create a Data item, estimating its serialized size via {@link MessageSizeUtil#ofDataField}. */
+ Data(Object raw, String id, GeneratedMessage.ExtendableMessage message,
+ Type type) {
+ this(raw, id, message, MessageSizeUtil.ofDataField(message, type));
+ }
+
+ /** Task ID: object UUID or reference beacon. */
+ String id() {
+ return id;
+ }
+
+ /** Serialized data size in bytes. */
+ int sizeBytes() {
+ return sizeBytes;
+ }
+
+ @Override
+ public void appendTo(WeaviateProtoBatch.BatchStreamRequest.Builder builder) {
+ // NOTE(review): a message that is neither BatchObject nor BatchReference
+ // is silently dropped here — confirm this is intentional rather than a
+ // missing error branch.
+ WeaviateProtoBatch.BatchStreamRequest.Data.Builder data = requireNonNull(builder, "builder is null")
+ .getDataBuilder();
+ if (message instanceof WeaviateProtoBatch.BatchObject object) {
+ data.getObjectsBuilder().addValues(object);
+ } else if (message instanceof WeaviateProtoBatch.BatchReference ref) {
+ data.getReferencesBuilder().addValues(ref);
+ }
+ }
+
+ @Override
+ public String toString() {
+ return "%s (%s)".formatted(raw.getClass().getSimpleName(), id);
+ }
+}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/DataTooBigException.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/DataTooBigException.java
new file mode 100644
index 000000000..5160bd27c
--- /dev/null
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/DataTooBigException.java
@@ -0,0 +1,16 @@
+package io.weaviate.client6.v1.api.collections.batch;
+
+import static java.util.Objects.requireNonNull;
+
+import io.weaviate.client6.v1.api.WeaviateException;
+
+/**
+ * DataTooBigException is thrown when a single object exceeds
+ * the maximum size of a gRPC message.
+ */
+public class DataTooBigException extends WeaviateException {
+ DataTooBigException(Data data, long maxSizeBytes) {
+ super(formatMessage(requireNonNull(data, "data is null"), maxSizeBytes));
+ }
+
+ /** Builds the exception message from the offending item and the configured limit. */
+ private static String formatMessage(Data data, long maxSizeBytes) {
+ return "%s with size=%dB exceeds maximum message size %dB".formatted(
+ data, data.sizeBytes(), maxSizeBytes);
+ }
+}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/DuplicateTaskException.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/DuplicateTaskException.java
new file mode 100644
index 000000000..f0949a6c0
--- /dev/null
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/DuplicateTaskException.java
@@ -0,0 +1,23 @@
+package io.weaviate.client6.v1.api.collections.batch;
+
+import io.weaviate.client6.v1.api.WeaviateException;
+
+/**
+ * DuplicateTaskException is thrown if task is submitted to the batch
+ * while another task with the same ID is in progress.
+ */
+public class DuplicateTaskException extends WeaviateException {
+ private final TaskHandle existing;
+
+ DuplicateTaskException(TaskHandle duplicate, TaskHandle existing) {
+ super("%s cannot be added to the batch while another task with the same ID is in progress");
+ this.existing = existing;
+ }
+
+ /**
+ * Get the currently in-progress handle that's a duplicate of the one submitted.
+ */
+ public TaskHandle getExisting() {
+ return existing;
+ }
+}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java
new file mode 100644
index 000000000..ba3f6b208
--- /dev/null
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Event.java
@@ -0,0 +1,162 @@
+package io.weaviate.client6.v1.api.collections.batch;
+
+import static java.util.Objects.requireNonNull;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+
+import io.grpc.Status;
+import io.weaviate.client6.v1.api.collections.batch.Event.Acks;
+import io.weaviate.client6.v1.api.collections.batch.Event.Backoff;
+import io.weaviate.client6.v1.api.collections.batch.Event.ClientError;
+import io.weaviate.client6.v1.api.collections.batch.Event.Oom;
+import io.weaviate.client6.v1.api.collections.batch.Event.Results;
+import io.weaviate.client6.v1.api.collections.batch.Event.Started;
+import io.weaviate.client6.v1.api.collections.batch.Event.StreamHangup;
+import io.weaviate.client6.v1.api.collections.batch.Event.TerminationEvent;
+
+sealed interface Event
+ permits Started, Acks, Results, Backoff, Oom, TerminationEvent, StreamHangup, ClientError {
+
+ // NOTE(review): Acks and Results use raw Collection/Map types; the generic
+ // type arguments appear to have been stripped (presumably Collection<String>
+ // and Map<String, String>) — confirm against the original source.
+
+ final static Event STARTED = new Started();
+ final static Event SHUTTING_DOWN = TerminationEvent.SHUTTING_DOWN;
+ final static Event EOF = TerminationEvent.EOF;
+
+ /**
+ * The server has acknowledged our Start message and is ready to receive data.
+ */
+ record Started() implements Event {
+ }
+
+ /**
+ * The server has added items from the previous message to its internal
+ * work queue; the client MAY send the next batch.
+ *
+ *
+ * The protocol guarantees that {@link Acks} will contain IDs for all
+ * items sent in the previous batch.
+ */
+ record Acks(Collection acked) implements Event {
+ public Acks {
+ acked = List.copyOf(requireNonNull(acked, "acked is null"));
+ }
+
+ @Override
+ public String toString() {
+ return "Acks";
+ }
+ }
+
+ /**
+ * Results for the insertion of previous batches.
+ *
+ *
+ * We assume that the server may return partial results, or return
+ * results out of the order of inserting messages.
+ */
+ record Results(Collection successful, Map errors) implements Event {
+ public Results {
+ successful = List.copyOf(requireNonNull(successful, "successful is null"));
+ errors = Map.copyOf(requireNonNull(errors, "errors is null"));
+ }
+
+ @Override
+ public String toString() {
+ return "Results";
+ }
+ }
+
+ /**
+ * Backoff communicates the optimal batch size (number of objects)
+ * with respect to the current load on the server.
+ *
+ *
+ * Backoff is an instruction, not a recommendation.
+ * On receiving this message, the client must ensure that
+ * all messages it produces, including the one being prepared,
+ * do not exceed the size limit indicated by {@link #maxSize}
+ * until the server sends another Backoff message. The limit
+ * MUST also be respected after a {@link BatchContext#reconnect}.
+ *
+ *
+ * The client MAY use the latest {@link #maxSize} as the default
+ * message limit in a new {@link BatchContext}, but is not required to.
+ */
+ record Backoff(int maxSize) implements Event {
+
+ @Override
+ public String toString() {
+ return "Backoff";
+ }
+ }
+
+ /**
+ * Out-Of-Memory.
+ *
+ *
+ * Items sent in the previous request cannot be accepted,
+ * as inserting them may exhaust the server's available disk space.
+ * On receiving this message, the client MUST stop producing
+ * messages immediately and await the {@link #SHUTTING_DOWN} event.
+ *
+ *
+ * Oom is the sibling of {@link Acks} with the opposite effect.
+ * The protocol guarantees that the server will respond with either of
+ * the two, but never both.
+ */
+ record Oom(int delaySeconds) implements Event {
+ }
+
+ /** Events that are part of the server's graceful shutdown strategy. */
+ enum TerminationEvent implements Event {
+ /**
+ * Server shutdown in progress.
+ *
+ *
+ * The server began the process of graceful shutdown, due to a
+ * scale-up event (if it previously reported {@link Oom}) or
+ * some other external event.
+ * On receiving this message, the client MUST stop producing
+ * messages immediately, close its side of the stream, and
+ * continue reading the server's messages until {@link #EOF}.
+ */
+ SHUTTING_DOWN,
+
+ /**
+ * Stream EOF.
+ *
+ *
+ * The server will not accept any more messages. If the client
+ * has more data to send, it SHOULD re-connect to another instance
+ * by re-opening the stream and continue processing the batch.
+ * If the client has previously sent {@link Message#STOP}, it can
+ * safely exit.
+ */
+ EOF;
+ }
+
+ /**
+ * StreamHangup means the RPC is "dead": the stream is closed
+ * and using it will result in an {@link IllegalStateException}.
+ */
+ record StreamHangup(Exception exception) implements Event {
+ static StreamHangup fromThrowable(Throwable t) {
+ Status status = Status.fromThrowable(t);
+ return new StreamHangup(status.asException());
+ }
+ }
+
+ /**
+ * ClientError means a client-side exception has happened,
+ * and is meant primarily for the "send" thread to propagate
+ * any exception it might catch.
+ *
+ *
+ * This MUST be treated as an irrecoverable condition, because
+ * it is likely caused by an internal issue (an NPE) or a bad
+ * input ({@link DataTooBigException}).
+ */
+ record ClientError(Exception exception) implements Event {
+ }
+}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/Message.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Message.java
new file mode 100644
index 000000000..9d387d1b1
--- /dev/null
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/Message.java
@@ -0,0 +1,32 @@
+package io.weaviate.client6.v1.api.collections.batch;
+
+import static java.util.Objects.requireNonNull;
+
+import java.util.Optional;
+
+import io.weaviate.client6.v1.api.collections.query.ConsistencyLevel;
+import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch;
+
+@FunctionalInterface
+interface Message {
+ /** Append this message's payload to the given batch-stream request builder. */
+ void appendTo(WeaviateProtoBatch.BatchStreamRequest.Builder builder);
+
+ /** Shared Stop message: sets an empty Stop payload on the request. */
+ static final Message STOP = builder -> builder
+ .setStop(WeaviateProtoBatch.BatchStreamRequest.Stop.getDefaultInstance());
+
+ /** Create a Start message, optionally carrying a consistency level. */
+ static Message start(Optional consistencyLevel) {
+ requireNonNull(consistencyLevel, "consistencyLevel is null");
+
+ final WeaviateProtoBatch.BatchStreamRequest.Start.Builder startPayload = WeaviateProtoBatch.BatchStreamRequest.Start
+ .newBuilder();
+ consistencyLevel.ifPresent(level -> level.appendTo(startPayload));
+ return builder -> builder.setStart(startPayload);
+ }
+
+ /** Create a Stop message. */
+ static Message stop() {
+ return STOP;
+ }
+}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/MessageSizeUtil.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/MessageSizeUtil.java
new file mode 100644
index 000000000..f3c934232
--- /dev/null
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/MessageSizeUtil.java
@@ -0,0 +1,44 @@
+package io.weaviate.client6.v1.api.collections.batch;
+
+import static java.util.Objects.requireNonNull;
+
+import com.google.protobuf.CodedOutputStream;
+import com.google.protobuf.GeneratedMessage;
+import com.google.protobuf.GeneratedMessageV3;
+
+import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch;
+
+final class MessageSizeUtil {
+ /**
+ * Serialized size of the tag that precedes the
+ * {@code BatchStreamRequest.data} field.
+ * Fix: declared {@code final} — this is a constant and was previously mutable.
+ */
+ private static final int DATA_TAG_SIZE = CodedOutputStream
+ .computeTagSize(WeaviateProtoBatch.BatchStreamRequest.DATA_FIELD_NUMBER);
+
+ private MessageSizeUtil() {
+ }
+
+ /**
+ * Adjust batch byte-size limit to account for the
+ * {@link WeaviateProtoBatch.BatchStreamRequest.Data} container.
+ *
+ *
+ * A protobuf field has layout {@code [tag][length(payload)][payload]},
+ * so to estimate the batch size correctly we must account for "tag"
+ * and "length", not just the raw payload.
+ *
+ * @throws IllegalArgumentException if the limit leaves no room for the payload.
+ */
+ static long maxSizeBytes(long maxSizeBytes) {
+ if (maxSizeBytes <= DATA_TAG_SIZE) {
+ // Fix: the old message said "at least %dB" while the check rejects
+ // that exact value; the true minimum is DATA_TAG_SIZE + 1.
+ throw new IllegalArgumentException("Maximum batch size must be greater than %dB".formatted(DATA_TAG_SIZE));
+ }
+ return maxSizeBytes - DATA_TAG_SIZE;
+ }
+
+ /**
+ * Calculate the size of a serialized
+ * {@link WeaviateProtoBatch.BatchStreamRequest.Data} field.
+ */
+ @SuppressWarnings("deprecation") // protoc uses GeneratedMessageV3
+ static int ofDataField(GeneratedMessage.ExtendableMessage message, Data.Type type) {
+ requireNonNull(type, "type is null");
+ requireNonNull(message, "message is null");
+ return CodedOutputStream.computeMessageSize(type.fieldNumber(), message);
+ }
+}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/ProtocolViolationException.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/ProtocolViolationException.java
new file mode 100644
index 000000000..d97b4839c
--- /dev/null
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/ProtocolViolationException.java
@@ -0,0 +1,48 @@
+package io.weaviate.client6.v1.api.collections.batch;
+
+import static java.util.Objects.requireNonNull;
+
+import java.util.List;
+
+import io.weaviate.client6.v1.api.WeaviateException;
+
+/**
+ * ProtocolViolationException describes unexpected server behavior in violation
+ * of the SSB protocol.
+ *
+ *
+ * This exception cannot be handled in a meaningful way and should be reported
+ * to the upstream Weaviate project.
+ */
+public class ProtocolViolationException extends WeaviateException {
+ ProtocolViolationException(String message) {
+ super(message);
+ }
+
+ /**
+ * Protocol violated because an event arrived while the client is in a state
+ * which doesn't expect to handle this event.
+ *
+ * @param current Current {@link BatchContext} state.
+ * @param event Server-side event.
+ * @return ProtocolViolationException with a formatted message.
+ */
+ static ProtocolViolationException illegalStateTransition(State current, Event event) {
+ return new ProtocolViolationException("%s arrived in %s state".formatted(event, current));
+ }
+
+ /**
+ * Protocol violated because some tasks from the previous Data message
+ * are not present in the Acks message.
+ *
+ * @param remaining IDs of the tasks that weren't ack'ed. MUST be non-empty:
+ * an empty list makes {@code remaining.get(0)} below throw
+ * {@link IndexOutOfBoundsException}.
+ * @return ProtocolViolationException with a formatted message.
+ */
+ static ProtocolViolationException incompleteAcks(List remaining) {
+ requireNonNull(remaining, "remaining is null");
+ return new ProtocolViolationException("IDs from previous Data message missing in Acks: '%s', ... (%d more)"
+ .formatted(remaining.get(0), remaining.size() - 1));
+
+ }
+}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java
new file mode 100644
index 000000000..a4ba41310
--- /dev/null
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/State.java
@@ -0,0 +1,42 @@
+package io.weaviate.client6.v1.api.collections.batch;
+
+interface State {
+ /**
+ * canSend returns a boolean indicating if sending
+ * an "insert" message is allowed in this state.
+ */
+ boolean canSend();
+
+ /**
+ * canPrepareNext returns a boolean indicating if accepting
+ * more items into the batch is allowed in this state.
+ */
+ boolean canPrepareNext();
+
+ /**
+ * Lifecycle hook that's called after the state is set.
+ *
+ *
+ * This hook MUST be called exactly once.
+ *
+ * The next state MUST NOT be set until onEnter returns.
+ *
+ * @param prev Previous state or null.
+ */
+ void onEnter(State prev);
+
+ /**
+ * onEvent handles incoming events; these can be generated by the server
+ * or by a different part of the program -- the {@link State} MUST NOT
+ * make any assumptions about the event's origin.
+ *
+ *
+ * How the event is handled is up to the concrete implementation.
+ * It may modify {@link BatchContext} internal state, via one of its
+ * package-private methods, including transitioning the context to a
+ * different state via {@link BatchContext#setState(State)}, or start
+ * a separate process, e.g. the OOM timer.
+ */
+ void onEvent(Event event);
+}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/StreamFactory.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/StreamFactory.java
new file mode 100644
index 000000000..c50eec6ee
--- /dev/null
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/StreamFactory.java
@@ -0,0 +1,13 @@
+package io.weaviate.client6.v1.api.collections.batch;
+
+import io.grpc.stub.StreamObserver;
+
+/**
+ * Factory for bidirectional gRPC streams.
+ *
+ * NOTE(review): generic type parameters appear to have been stripped from
+ * this file; the parameter docs below suggest the interface was parameterized
+ * over a send type and a receive type — confirm against the original source.
+ *
+ * @param the type of the object sent down the stream.
+ * @param the type of the object received from the stream.
+ */
+@FunctionalInterface
+interface StreamFactory {
+ /** Create a new stream for the send-recv observer pair. */
+ StreamObserver createStream(StreamObserver recv);
+}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java
new file mode 100644
index 000000000..aad9c8a71
--- /dev/null
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TaskHandle.java
@@ -0,0 +1,176 @@
+package io.weaviate.client6.v1.api.collections.batch;
+
+import static java.util.Objects.requireNonNull;
+
+import java.time.Instant;
+import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+
+import javax.annotation.concurrent.ThreadSafe;
+
+import com.google.protobuf.GeneratedMessage;
+import com.google.protobuf.GeneratedMessageV3;
+
+import io.weaviate.client6.v1.api.collections.WeaviateObject;
+import io.weaviate.client6.v1.api.collections.data.BatchReference;
+
+@ThreadSafe
+@SuppressWarnings("deprecation") // protoc uses GeneratedMessageV3
+public final class TaskHandle {
+ /** Sentinel handle used only for {@code ==} comparison; see the private constructor. */
+ static final TaskHandle POISON = new TaskHandle();
+
+ /**
+ * Input value as passed by the user, wrapped with its serialized form.
+ *
+ *
+ * Changes in the underlying raw value will not be reflected
+ * in the {@link TaskHandle} (e.g. the serialized version is not updated),
+ * so users SHOULD treat items passed to and retrieved from {@link TaskHandle}
+ * as effectively unmodifiable.
+ */
+ private final Data data;
+
+ /** Completes once the server has acknowledged ("ack'ed") this task. */
+ private final CompletableFuture acked = new CompletableFuture<>();
+
+ /** Task outcome: an empty {@code error} means success. */
+ public record Result(Optional error) {
+ public Result {
+ requireNonNull(error, "error is null");
+ }
+ }
+
+ /**
+ * Task result; completes when the client receives {@link Event.Results}
+ * containing this handle's {@link #id}.
+ */
+ private final CompletableFuture result = new CompletableFuture<>();
+
+ /** The number of times this task has been retried. */
+ private final int retries;
+
+ /** Task creation timestamp. */
+ private final Instant createdAt = Instant.now();
+
+ private TaskHandle(Data data, int retries) {
+ this.data = requireNonNull(data, "data is null");
+
+ assert retries >= 0 : "negative retries";
+ this.retries = retries;
+ }
+
+ /** Constructor for {@link WeaviateObject}. */
+ // Fix: parameter type read "WeaviateObject>" (invalid syntax); restored the
+ // wildcard, which was presumably lost in transit.
+ TaskHandle(WeaviateObject<?> object, GeneratedMessage.ExtendableMessage data) {
+ this(new Data(object, object.uuid(), data, Data.Type.OBJECT), 0);
+ }
+
+ /** Constructor for {@link BatchReference}. */
+ TaskHandle(BatchReference reference, GeneratedMessage.ExtendableMessage data) {
+ this(new Data(reference, reference.target().beacon(), data, Data.Type.REFERENCE), 0);
+ }
+
+ /**
+ * Poison pill constructor.
+ *
+ *
+ * A handle created with this constructor should not be
+ * used for anything other than direct comparison using the {@code ==} operator;
+ * calling any method on a poison pill is likely to result in a
+ * {@link NullPointerException} being thrown.
+ */
+ private TaskHandle() {
+ this.data = null;
+ this.retries = 0;
+ }
+
+ /**
+ * Creates a new task containing the same data as this task and the
+ * {@link #retries} counter incremented by 1. The {@link #acked} and
+ * {@link #result} futures are not copied to the returned task.
+ *
+ * @return Task handle.
+ */
+ TaskHandle retry() {
+ return new TaskHandle(data, retries + 1);
+ }
+
+ /** Task ID: object UUID or reference beacon. */
+ String id() {
+ return data.id();
+ }
+
+ Data data() {
+ return data;
+ }
+
+ /** Set the {@link #acked} flag. */
+ void setAcked() {
+ acked.complete(null);
+ }
+
+ /**
+ * Mark the task successful. This status cannot be changed, so calling
+ * {@link #setError} afterwards will have no effect.
+ */
+ void setSuccess() {
+ setResult(new Result(Optional.empty()));
+ }
+
+ /**
+ * Mark the task failed. This status cannot be changed, so calling
+ * {@link #setSuccess} afterwards will have no effect.
+ *
+ * @param error Error message. Null values are tolerated, but are only expected
+ * to occur due to a server's mistake.
+ * Do not use {@code setError(null)} if the server reports success
+ * status for the task; prefer {@link #setSuccess} in that case.
+ */
+ void setError(String error) {
+ setResult(new Result(Optional.ofNullable(error)));
+ }
+
+ /**
+ * Set result for this task.
+ *
+ * @throws IllegalStateException if the task has not been ack'ed.
+ */
+ private void setResult(Result result) {
+ if (!acked.isDone()) {
+ // TODO(dyma): can this happen due to us?
+ throw new IllegalStateException("Result can only be set for an ack'ed task");
+ }
+ this.result.complete(result);
+ }
+
+ /**
+ * Check if the task has been accepted.
+ *
+ * @return A future which completes when the server has accepted the task.
+ */
+ public CompletableFuture isAcked() {
+ return acked;
+ }
+
+ /**
+ * Retrieve the result for this task.
+ *
+ * @return A future which completes when the server
+ * has reported the result for this task.
+ */
+ public CompletableFuture result() {
+ return result;
+ }
+
+ /**
+ * Number of times this task has been retried. Since {@link TaskHandle} is
+ * immutable, this value does not change, but retrying a task via
+ * {@link BatchContext#retry} is reflected in the returned handle's
+ * {@link #timesRetried}.
+ */
+ public int timesRetried() {
+ return retries;
+ }
+
+ @Override
+ public String toString() {
+ // Bug fix: the previous format string had no placeholders, so the three
+ // arguments were ignored and toString() always returned "TaskHandle".
+ return "TaskHandle[id=%s, retries=%d, createdAt=%s]".formatted(id(), retries, createdAt);
+ }
+}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java
new file mode 100644
index 000000000..2e6344ada
--- /dev/null
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/TranslatingStreamFactory.java
@@ -0,0 +1,147 @@
+package io.weaviate.client6.v1.api.collections.batch;
+
+import static java.util.Objects.requireNonNull;
+
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import io.grpc.stub.StreamObserver;
+import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc;
+import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch;
+import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamReply;
+
+/**
+ * TranslatingStreamFactory is an adaptor for the
+ * {@link WeaviateGrpc.WeaviateStub#batchStream} factory. The returned stream
+ * translates client-side messages into protobuf requests and server-side
+ * replies into events.
+ *
+ * NOTE(review): generic type arguments appear to have been stripped from this
+ * file (raw StreamFactory/StreamObserver/List/Map) — confirm against the
+ * original source.
+ *
+ * @see Message
+ * @see Event
+ */
+class TranslatingStreamFactory implements StreamFactory {
+ private final StreamFactory protoFactory;
+
+ TranslatingStreamFactory(
+ StreamFactory protoFactory) {
+ this.protoFactory = requireNonNull(protoFactory, "protoFactory is null");
+ }
+
+ @Override
+ public StreamObserver createStream(StreamObserver recv) {
+ return new Messenger(protoFactory.createStream(new Eventer(recv)));
+ }
+
+ /**
+ * DelegatingStreamObserver delegates {@link #onCompleted} and {@link #onError}
+ * to another observer and translates the messages in {@link #onNext}.
+ *
+ * NOTE(review): this is a non-static inner class, so each instance holds a
+ * reference to the enclosing factory; it doesn't use outer state — consider
+ * making it (and its subclasses) static nested classes.
+ *
+ * @param the type of the incoming message.
+ * @param the type of the message handed to the delegate.
+ */
+ private abstract class DelegatingStreamObserver implements StreamObserver {
+ protected final StreamObserver delegate;
+
+ protected DelegatingStreamObserver(StreamObserver delegate) {
+ this.delegate = delegate;
+ }
+
+ @Override
+ public void onCompleted() {
+ delegate.onCompleted();
+ }
+
+ @Override
+ public void onError(Throwable t) {
+ delegate.onError(t);
+ }
+ }
+
+ /**
+ * Messenger translates client's messages into batch stream requests.
+ *
+ * @see Message
+ */
+ private final class Messenger extends DelegatingStreamObserver {
+ private Messenger(StreamObserver delegate) {
+ super(delegate);
+ }
+
+ @Override
+ public void onNext(Message message) {
+ WeaviateProtoBatch.BatchStreamRequest.Builder builder = WeaviateProtoBatch.BatchStreamRequest.newBuilder();
+ message.appendTo(builder);
+ delegate.onNext(builder.build());
+ }
+ }
+
+ /**
+ * Eventer translates server replies into events.
+ *
+ * @see Event
+ */
+ private final class Eventer extends DelegatingStreamObserver {
+ private Eventer(StreamObserver delegate) {
+ super(delegate);
+ }
+
+ @Override
+ public void onNext(BatchStreamReply reply) {
+ // NOTE(review): if the server adds a new message case, this switch falls
+ // through with event == null and delegate.onNext(null) is invoked —
+ // consider a default branch that fails loudly instead.
+ Event event = null;
+ switch (reply.getMessageCase()) {
+ case STARTED:
+ event = Event.STARTED;
+ break;
+ case SHUTTING_DOWN:
+ event = Event.SHUTTING_DOWN;
+ break;
+ case SHUTDOWN:
+ event = Event.EOF;
+ break;
+ case OUT_OF_MEMORY:
+ // TODO(dyma): read this value from the message
+ event = new Event.Oom(300);
+ break;
+ case BACKOFF:
+ event = new Event.Backoff(reply.getBackoff().getBatchSize());
+ break;
+ case ACKS:
+ Stream uuids = reply.getAcks().getUuidsList().stream();
+ Stream beacons = reply.getAcks().getBeaconsList().stream();
+ event = new Event.Acks(Stream.concat(uuids, beacons).toList());
+ break;
+ case RESULTS:
+ List successful = reply.getResults().getSuccessesList().stream()
+ .map(detail -> {
+ if (detail.hasUuid()) {
+ return detail.getUuid();
+ } else if (detail.hasBeacon()) {
+ return detail.getBeacon();
+ }
+ throw new IllegalArgumentException("Result has neither UUID nor a beacon");
+ })
+ .toList();
+
+ Map errors = reply.getResults().getErrorsList().stream()
+ .map(detail -> {
+ String error = requireNonNull(detail.getError(), "error is null");
+ if (detail.hasUuid()) {
+ return Map.entry(detail.getUuid(), error);
+ } else if (detail.hasBeacon()) {
+ return Map.entry(detail.getBeacon(), error);
+ }
+ throw new IllegalArgumentException("Result has neither UUID nor a beacon");
+ })
+ .collect(Collectors.toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue));
+ event = new Event.Results(successful, errors);
+ break;
+ case MESSAGE_NOT_SET:
+ throw new ProtocolViolationException("Message not set");
+ }
+
+ delegate.onNext(event);
+ }
+ }
+}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/batch/WeaviateBatchClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/batch/WeaviateBatchClient.java
new file mode 100644
index 000000000..453862f5e
--- /dev/null
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/batch/WeaviateBatchClient.java
@@ -0,0 +1,59 @@
+package io.weaviate.client6.v1.api.collections.batch;
+
+import static java.util.Objects.requireNonNull;
+
+import java.util.OptionalInt;
+
+import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults;
+import io.weaviate.client6.v1.internal.TransportOptions;
+import io.weaviate.client6.v1.internal.grpc.GrpcTransport;
+import io.weaviate.client6.v1.internal.orm.CollectionDescriptor;
+
+/**
+ * Entry point for server-side batching: opens a {@link BatchContext}
+ * over a gRPC batch stream.
+ */
+public class WeaviateBatchClient {
+ private final CollectionHandleDefaults defaults;
+ private final CollectionDescriptor collectionDescriptor;
+ private final GrpcTransport grpcTransport;
+
+ public WeaviateBatchClient(
+ GrpcTransport grpcTransport,
+ CollectionDescriptor collectionDescriptor,
+ CollectionHandleDefaults defaults) {
+ this.defaults = requireNonNull(defaults, "defaults is null");
+ this.collectionDescriptor = requireNonNull(collectionDescriptor, "collectionDescriptor is null");
+ this.grpcTransport = requireNonNull(grpcTransport, "grpcTransport is null");
+ }
+
+ /** Copy constructor with new defaults. */
+ public WeaviateBatchClient(WeaviateBatchClient c, CollectionHandleDefaults defaults) {
+ this.defaults = requireNonNull(defaults, "defaults is null");
+ this.collectionDescriptor = c.collectionDescriptor;
+ this.grpcTransport = c.grpcTransport;
+ }
+
+ /**
+ * Open a new batch stream and return its context, already started.
+ *
+ * @throws IllegalStateException if the server did not report a maximum
+ * gRPC message size, which server-side batching requires.
+ */
+ public BatchContext start() {
+ OptionalInt maxSizeBytes = grpcTransport.maxMessageSizeBytes();
+ if (maxSizeBytes.isEmpty()) {
+ throw new IllegalStateException("Server must have grpcMaxMessageSize configured to use server-side batching");
+ }
+
+ StreamFactory streamFactory = new TranslatingStreamFactory(grpcTransport::createStream);
+ BatchContext context = new BatchContext<>(
+ streamFactory,
+ maxSizeBytes.getAsInt(),
+ collectionDescriptor,
+ defaults);
+
+ // Periodic reconnects for Weaviate Cloud on GCP — presumably to work
+ // around idle-stream termination by the load balancer; TODO confirm.
+ if (isWeaviateCloudOnGoogleCloud(grpcTransport.host())) {
+ context.scheduleReconnect(GCP_RECONNECT_INTERVAL_SECONDS);
+ }
+
+ context.start();
+ return context;
+ }
+
+ private static final int GCP_RECONNECT_INTERVAL_SECONDS = 160;
+
+ private static boolean isWeaviateCloudOnGoogleCloud(String host) {
+ return TransportOptions.isWeaviateDomain(host) && TransportOptions.isGoogleCloudDomain(host);
+ }
+}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/config/GetShardsRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/config/GetShardsRequest.java
index 5237fed26..c27c13aa7 100644
--- a/src/main/java/io/weaviate/client6/v1/api/collections/config/GetShardsRequest.java
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/config/GetShardsRequest.java
@@ -21,8 +21,8 @@ public static final Endpoint> endpoint(
return SimpleEndpoint.noBody(
request -> "GET",
request -> "/schema/" + collection.collectionName() + "/shards",
- request -> defaults.tenant() != null
- ? Map.of("tenant", defaults.tenant())
+ request -> defaults.tenant().isPresent()
+ ? Map.of("tenant", defaults.tenant().get())
: Collections.emptyMap(),
(statusCode, response) -> (List) JSON.deserialize(response, TypeToken.getParameterized(
List.class, Shard.class)));
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/BatchReference.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/BatchReference.java
index a83652be5..503ec59d1 100644
--- a/src/main/java/io/weaviate/client6/v1/api/collections/data/BatchReference.java
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/BatchReference.java
@@ -1,5 +1,7 @@
package io.weaviate.client6.v1.api.collections.data;
+import static java.util.Objects.requireNonNull;
+
import java.io.IOException;
import java.util.Arrays;
@@ -9,7 +11,14 @@
import io.weaviate.client6.v1.api.collections.WeaviateObject;
-public record BatchReference(String fromCollection, String fromProperty, String fromUuid, ObjectReference reference) {
+public record BatchReference(String fromCollection, String fromProperty, String fromUuid, ObjectReference target) {
+
+ public BatchReference {
+ requireNonNull(fromCollection, "fromCollection is null");
+ requireNonNull(fromProperty, "fromProperty is null");
+ requireNonNull(fromUuid, "fromUuid is null");
+ requireNonNull(target, "target is null");
+ }
public static BatchReference[] objects(WeaviateObject> fromObject, String fromProperty,
WeaviateObject>... toObjects) {
@@ -39,7 +48,7 @@ public void write(JsonWriter out, BatchReference value) throws IOException {
out.value(ObjectReference.toBeacon(value.fromCollection, value.fromProperty, value.fromUuid));
out.name("to");
- out.value(ObjectReference.toBeacon(value.reference.collection(), value.reference.uuid()));
+ out.value(ObjectReference.toBeacon(value.target.collection(), value.target.uuid()));
out.endObject();
}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java
index 2fff8681a..e60133957 100644
--- a/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteManyRequest.java
@@ -29,11 +29,9 @@ public static Rpc Rpc, WeaviateProtoBat
request -> {
var message = WeaviateProtoBatch.BatchObjectsRequest.newBuilder();
- var batch = request.objects.stream().map(obj -> {
- var batchObject = WeaviateProtoBatch.BatchObject.newBuilder();
- buildObject(batchObject, obj, collection, defaults);
- return batchObject.build();
- }).toList();
-
+ var batch = request.objects.stream()
+ .map(obj -> buildObject(obj, collection, defaults))
+ .toList();
message.addAllObjects(batch);
- if (defaults.consistencyLevel() != null) {
- defaults.consistencyLevel().appendTo(message);
+ if (defaults.consistencyLevel().isPresent()) {
+ defaults.consistencyLevel().get().appendTo(message);
}
return message.build();
},
response -> {
@@ -92,10 +94,11 @@ public static Rpc, WeaviateProtoBat
() -> WeaviateFutureStub::batchObjects);
}
- public static void buildObject(WeaviateProtoBatch.BatchObject.Builder object,
+ public static WeaviateProtoBatch.BatchObject buildObject(
WeaviateObject insert,
CollectionDescriptor collection,
CollectionHandleDefaults defaults) {
+ var object = WeaviateProtoBatch.BatchObject.newBuilder();
object.setCollection(collection.collectionName());
if (insert.uuid() != null) {
@@ -121,9 +124,7 @@ public static void buildObject(WeaviateProtoBatch.BatchObject.Builder object
}).toList();
object.addAllVectors(vectors);
}
- if (defaults.tenant() != null) {
- object.setTenant(defaults.tenant());
- }
+ defaults.tenant().ifPresent(object::setTenant);
var singleRef = new ArrayList();
var multiRef = new ArrayList();
@@ -158,6 +159,7 @@ public static void buildObject(WeaviateProtoBatch.BatchObject.Builder object
properties.setNonRefProperties(nonRef);
}
object.setProperties(properties);
+ return object.build();
}
@SuppressWarnings("unchecked")
@@ -330,4 +332,20 @@ private static com.google.protobuf.Struct marshalStruct(Map prop
});
return struct.build();
}
+
+ public static WeaviateProtoBatch.BatchReference buildReference(BatchReference reference, Optional tenant) {
+ requireNonNull(reference, "reference is null");
+ WeaviateProtoBatch.BatchReference.Builder builder = WeaviateProtoBatch.BatchReference.newBuilder();
+ builder
+ .setName(reference.fromProperty())
+ .setFromCollection(reference.fromCollection())
+ .setFromUuid(reference.fromUuid())
+ .setToUuid(reference.target().uuid());
+
+ if (reference.target().collection() != null) {
+ builder.setToCollection(reference.target().collection());
+ }
+ tenant.ifPresent(t -> builder.setTenant(t));
+ return builder.build();
+ }
}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java
index 8588eb760..3bfcdf52d 100644
--- a/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java
@@ -27,14 +27,14 @@ public static final Endpoint, Wea
return new SimpleEndpoint<>(
request -> "POST",
request -> "/objects/",
- request -> defaults.consistencyLevel() != null
- ? Map.of("consistency_level", defaults.consistencyLevel())
+ request -> defaults.consistencyLevel().isPresent()
+ ? Map.of("consistency_level", defaults.consistencyLevel().get())
: Collections.emptyMap(),
request -> JSON.serialize(
new WeaviateObject<>(
request.object.uuid(),
collection.collectionName(),
- defaults.tenant(),
+ defaults.tenant().orElse(null),
request.object.properties(),
request.object.vectors(),
request.object.createdAt(),
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/ObjectReference.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/ObjectReference.java
index bb6a0f27e..822c5f54d 100644
--- a/src/main/java/io/weaviate/client6/v1/api/collections/data/ObjectReference.java
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/ObjectReference.java
@@ -67,6 +67,10 @@ public static ObjectReference[] collection(String collection, String... uuids) {
.toArray(ObjectReference[]::new);
}
+ public String beacon() {
+ return toBeacon(collection, uuid);
+ }
+
public static String toBeacon(String collection, String uuid) {
return toBeacon(collection, null, uuid);
}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddManyRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddManyRequest.java
index 284688daa..cfb2a4667 100644
--- a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddManyRequest.java
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReferenceAddManyRequest.java
@@ -4,6 +4,7 @@
import java.util.List;
import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults;
+import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch;
import io.weaviate.client6.v1.internal.json.JSON;
import io.weaviate.client6.v1.internal.rest.Endpoint;
import io.weaviate.client6.v1.internal.rest.SimpleEndpoint;
@@ -32,4 +33,7 @@ public static final Endpoint
});
}
+ public static WeaviateProtoBatch.BatchReference buildReference(ObjectReference reference) {
+ // TODO(dyma): marshal ObjectReference into a proto BatchReference.
+ throw new UnsupportedOperationException("buildReference(ObjectReference) is not implemented yet");
+ }
}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReplaceObjectRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReplaceObjectRequest.java
index 13a1afacb..45419afff 100644
--- a/src/main/java/io/weaviate/client6/v1/api/collections/data/ReplaceObjectRequest.java
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/ReplaceObjectRequest.java
@@ -26,14 +26,14 @@ static final Endpoint, Void> end
return SimpleEndpoint.sideEffect(
request -> "PUT",
request -> "/objects/" + collection.collectionName() + "/" + request.object.uuid(),
- request -> defaults.consistencyLevel() != null
- ? Map.of("consistency_level", defaults.consistencyLevel())
+ request -> defaults.consistencyLevel().isPresent()
+ ? Map.of("consistency_level", defaults.consistencyLevel().get())
: Collections.emptyMap(),
request -> JSON.serialize(
new WeaviateObject<>(
request.object.uuid(),
collection.collectionName(),
- defaults.tenant(),
+ defaults.tenant().orElse(null),
request.object.properties(),
request.object.vectors(),
request.object.createdAt(),
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/UpdateObjectRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/UpdateObjectRequest.java
index 6157a1cc8..65a1a66f1 100644
--- a/src/main/java/io/weaviate/client6/v1/api/collections/data/UpdateObjectRequest.java
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/UpdateObjectRequest.java
@@ -26,14 +26,14 @@ static final Endpoint, Void> endp
return SimpleEndpoint.sideEffect(
request -> "PATCH",
request -> "/objects/" + collection.collectionName() + "/" + request.object.uuid(),
- request -> defaults.consistencyLevel() != null
- ? Map.of("consistency_level", defaults.consistencyLevel())
+ request -> defaults.consistencyLevel().isPresent()
+ ? Map.of("consistency_level", defaults.consistencyLevel().get())
: Collections.emptyMap(),
request -> JSON.serialize(
new WeaviateObject<>(
request.object.uuid(),
collection.collectionName(),
- defaults.tenant(),
+ defaults.tenant().orElse(null),
request.object.properties(),
request.object.vectors(),
request.object.createdAt(),
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java
index c7497ec64..19613dff2 100644
--- a/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java
@@ -35,7 +35,7 @@ public WeaviateDataClient(
this.defaults = defaults;
}
- /** Copy constructor that updates the {@link #query} to use new defaults. */
+ /** Copy constructor with new defaults. */
public WeaviateDataClient(WeaviateDataClient c, CollectionHandleDefaults defaults) {
this.restTransport = c.restTransport;
this.grpcTransport = c.grpcTransport;
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/ConsistencyLevel.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/ConsistencyLevel.java
index 326609d2e..fc688cd41 100644
--- a/src/main/java/io/weaviate/client6/v1/api/collections/query/ConsistencyLevel.java
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/ConsistencyLevel.java
@@ -43,6 +43,10 @@ public final void appendTo(WeaviateProtoBatch.BatchObjectsRequest.Builder req) {
req.setConsistencyLevel(consistencyLevel);
}
+ public final void appendTo(WeaviateProtoBatch.BatchStreamRequest.Start.Builder req) {
+ req.setConsistencyLevel(consistencyLevel);
+ }
+
@Override
public String toString() {
return queryParameter;
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java
index 625dde30d..8bb0db91b 100644
--- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java
@@ -37,11 +37,9 @@ public static WeaviateProtoSearchGet.SearchRequest marshal(
}
request.operator.appendTo(message);
- if (defaults.tenant() != null) {
- message.setTenant(defaults.tenant());
- }
- if (defaults.consistencyLevel() != null) {
- defaults.consistencyLevel().appendTo(message);
+ defaults.tenant().ifPresent(message::setTenant);
+ if (defaults.consistencyLevel().isPresent()) {
+ defaults.consistencyLevel().get().appendTo(message);
}
if (request.groupBy != null) {
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java
index b1bc7369e..751492591 100644
--- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java
@@ -3,7 +3,6 @@
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
-import java.util.Map;
import java.util.UUID;
import java.util.stream.Stream;
@@ -92,20 +91,9 @@ static WeaviateObject unmarshalWithReferences(
(map, ref) -> {
var refObjects = ref.getPropertiesList().stream()
.map(property -> {
- var reference = unmarshalWithReferences(
+ return (Reference) unmarshalWithReferences(
property, property.getMetadata(),
CollectionDescriptor.ofMap(property.getTargetCollection()));
- return (Reference) new WeaviateObject<>(
- reference.uuid(),
- reference.collection(),
- // TODO(dyma): we can get tenant from CollectionHandle
- null, // tenant is not returned in the query
- (Map) reference.properties(),
- reference.vectors(),
- reference.createdAt(),
- reference.lastUpdatedAt(),
- reference.queryMetadata(),
- reference.references());
})
.toList();
diff --git a/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java b/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java
index 897bb28cd..06c0b6c15 100644
--- a/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java
+++ b/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java
@@ -57,4 +57,20 @@ public H headers() {
public TrustManagerFactory trustManagerFactory() {
return this.trustManagerFactory;
}
+
+ /**
+ * isWeaviateDomain returns true if the host matches weaviate.io,
+ * semi.technology, or weaviate.cloud domain.
+ */
+ public static boolean isWeaviateDomain(String host) {
+ var lower = host.toLowerCase();
+ return lower.contains("weaviate.io") ||
+ lower.contains("semi.technology") ||
+ lower.contains("weaviate.cloud");
+ }
+
+ public static boolean isGoogleCloudDomain(String host) {
+ var lower = host.toLowerCase();
+ return lower.contains("gcp");
+ }
}
diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java
index d12255d22..385808ffc 100644
--- a/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java
+++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java
@@ -1,6 +1,10 @@
package io.weaviate.client6.v1.internal.grpc;
+import static java.util.Objects.requireNonNull;
+
+import java.util.OptionalInt;
import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLException;
@@ -16,57 +20,57 @@
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext;
import io.grpc.stub.AbstractStub;
import io.grpc.stub.MetadataUtils;
+import io.grpc.stub.StreamObserver;
import io.weaviate.client6.v1.api.WeaviateApiException;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateFutureStub;
+import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamReply;
+import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamRequest;
public final class DefaultGrpcTransport implements GrpcTransport {
+ /**
+ * ListenableFuture callbacks are executed
+ * in the same thread they are called from.
+ */
+ private static final Executor FUTURE_CALLBACK_EXECUTOR = Runnable::run;
+
+ private final GrpcChannelOptions transportOptions;
private final ManagedChannel channel;
private final WeaviateBlockingStub blockingStub;
private final WeaviateFutureStub futureStub;
- private final GrpcChannelOptions transportOptions;
-
private TokenCallCredentials callCredentials;
public DefaultGrpcTransport(GrpcChannelOptions transportOptions) {
- this.transportOptions = transportOptions;
- this.channel = buildChannel(transportOptions);
-
- var blockingStub = WeaviateGrpc.newBlockingStub(channel)
- .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers()));
-
- var futureStub = WeaviateGrpc.newFutureStub(channel)
- .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers()));
-
- if (transportOptions.maxMessageSize() != null) {
- var max = transportOptions.maxMessageSize();
- blockingStub = blockingStub.withMaxInboundMessageSize(max).withMaxOutboundMessageSize(max);
- futureStub = futureStub.withMaxInboundMessageSize(max).withMaxOutboundMessageSize(max);
- }
+ requireNonNull(transportOptions, "transportOptions is null");
+ this.transportOptions = transportOptions;
if (transportOptions.tokenProvider() != null) {
this.callCredentials = new TokenCallCredentials(transportOptions.tokenProvider());
- blockingStub = blockingStub.withCallCredentials(callCredentials);
- futureStub = futureStub.withCallCredentials(callCredentials);
}
- this.blockingStub = blockingStub;
- this.futureStub = futureStub;
+ this.channel = buildChannel(transportOptions);
+ this.blockingStub = configure(WeaviateGrpc.newBlockingStub(channel));
+ this.futureStub = configure(WeaviateGrpc.newFutureStub(channel));
}
private > StubT applyTimeout(StubT stub, Rpc, ?, ?, ?> rpc) {
if (transportOptions.timeout() == null) {
return stub;
}
- var timeout = rpc.isInsert()
+ int timeout = rpc.isInsert()
? transportOptions.timeout().insertSeconds()
: transportOptions.timeout().querySeconds();
return stub.withDeadlineAfter(timeout, TimeUnit.SECONDS);
}
+ @Override
+ public OptionalInt maxMessageSizeBytes() {
+ return transportOptions.maxMessageSize();
+ }
+
@Override
public ResponseT performRequest(RequestT request,
Rpc rpc) {
@@ -96,7 +100,9 @@ public CompletableFuture perf
* reusing the thread in which the original future is completed.
*/
private static final CompletableFuture toCompletableFuture(ListenableFuture listenable) {
- var completable = new CompletableFuture();
+ requireNonNull(listenable, "listenable is null");
+
+ CompletableFuture completable = new CompletableFuture<>();
Futures.addCallback(listenable, new FutureCallback() {
@Override
@@ -113,13 +119,14 @@ public void onFailure(Throwable t) {
completable.completeExceptionally(t);
}
- }, Runnable::run);
+ }, FUTURE_CALLBACK_EXECUTOR);
return completable;
}
private static ManagedChannel buildChannel(GrpcChannelOptions transportOptions) {
- var channel = NettyChannelBuilder.forAddress(transportOptions.host(), transportOptions.port());
+ requireNonNull(transportOptions, "transportOptions is null");
+ NettyChannelBuilder channel = NettyChannelBuilder.forAddress(transportOptions.host(), transportOptions.port());
if (transportOptions.isSecure()) {
channel.useTransportSecurity();
} else {
@@ -140,10 +147,29 @@ private static ManagedChannel buildChannel(GrpcChannelOptions transportOptions)
}
channel.intercept(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers()));
-
return channel.build();
}
+ @Override
+ public StreamObserver createStream(StreamObserver recv) {
+ return configure(WeaviateGrpc.newStub(channel)).batchStream(recv);
+ }
+
+ /** Apply common configuration to a stub. */
+ private > S configure(S stub) {
+ requireNonNull(stub, "stub is null");
+
+ stub = stub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers()));
+ if (transportOptions.maxMessageSize().isPresent()) {
+ int max = transportOptions.maxMessageSize().getAsInt();
+ stub = stub.withMaxInboundMessageSize(max).withMaxOutboundMessageSize(max);
+ }
+ if (callCredentials != null) {
+ stub = stub.withCallCredentials(callCredentials);
+ }
+ return stub;
+ }
+
@Override
public void close() throws Exception {
channel.shutdown();
@@ -151,4 +177,9 @@ public void close() throws Exception {
callCredentials.close();
}
}
+
+ @Override
+ public String host() {
+ return transportOptions.host();
+ }
}
diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java
index 5e4453d7f..96366cb5f 100644
--- a/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java
+++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java
@@ -1,6 +1,7 @@
package io.weaviate.client6.v1.internal.grpc;
import java.util.Map;
+import java.util.OptionalInt;
import javax.net.ssl.TrustManagerFactory;
@@ -10,7 +11,7 @@
import io.weaviate.client6.v1.internal.TransportOptions;
public class GrpcChannelOptions extends TransportOptions {
- private final Integer maxMessageSize;
+ private final OptionalInt maxMessageSize;
public GrpcChannelOptions(String scheme, String host, int port, Map headers,
TokenProvider tokenProvider, TrustManagerFactory tmf, Timeout timeout) {
@@ -18,17 +19,18 @@ public GrpcChannelOptions(String scheme, String host, int port, Map ResponseT performRequest(RequestT request,
- Rpc rpc);
+ ResponseT performRequest(RequestT request,
+ Rpc rpc);
+
+ CompletableFuture performRequestAsync(RequestT request,
+ Rpc rpc);
+
+ /**
+ * Create stream for batch insertion.
+ *
+ * @apiNote Batch insertion is presently the only operation performed over a
+ * bidirectional (stream-stream) connection, which is why we do not
+ * parametrize this method.
+ */
+ StreamObserver createStream(
+ StreamObserver recv);
+
+ String host();
- CompletableFuture performRequestAsync(RequestT request,
- Rpc rpc);
+ /**
+ * Maximum inbound/outbound message size supported by the underlying channel.
+ */
+ OptionalInt maxMessageSizeBytes();
}
diff --git a/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaultsTest.java b/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaultsTest.java
index 4b420b3fd..95d7cd126 100644
--- a/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaultsTest.java
+++ b/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleDefaultsTest.java
@@ -24,7 +24,7 @@ public class CollectionHandleDefaultsTest {
/** All defaults are {@code null} if none were set. */
@Test
public void test_defaults() {
- Assertions.assertThat(HANDLE_NONE.consistencyLevel()).as("default ConsistencyLevel").isNull();
+ Assertions.assertThat(HANDLE_NONE.consistencyLevel()).as("default ConsistencyLevel").isEmpty();
Assertions.assertThat(HANDLE_NONE.tenant()).as("default tenant").isNull();
}
@@ -35,8 +35,8 @@ public void test_defaults() {
@Test
public void test_withConsistencyLevel() {
var handle = HANDLE_NONE.withConsistencyLevel(ConsistencyLevel.QUORUM);
- Assertions.assertThat(handle.consistencyLevel()).isEqualTo(ConsistencyLevel.QUORUM);
- Assertions.assertThat(HANDLE_NONE.consistencyLevel()).isNull();
+ Assertions.assertThat(handle.consistencyLevel()).get().isEqualTo(ConsistencyLevel.QUORUM);
+ Assertions.assertThat(HANDLE_NONE.consistencyLevel()).isEmpty();
}
/**
@@ -46,8 +46,8 @@ public void test_withConsistencyLevel() {
@Test
public void test_withConsistencyLevel_async() {
var handle = HANDLE_NONE_ASYNC.withConsistencyLevel(ConsistencyLevel.QUORUM);
- Assertions.assertThat(handle.consistencyLevel()).isEqualTo(ConsistencyLevel.QUORUM);
- Assertions.assertThat(HANDLE_NONE_ASYNC.consistencyLevel()).isNull();
+ Assertions.assertThat(handle.consistencyLevel()).get().isEqualTo(ConsistencyLevel.QUORUM);
+ Assertions.assertThat(HANDLE_NONE_ASYNC.consistencyLevel()).isEmpty();
}
/**
@@ -58,7 +58,7 @@ public void test_withConsistencyLevel_async() {
public void test_withTenant() {
var handle = HANDLE_NONE.withTenant("john_doe");
Assertions.assertThat(handle.tenant()).isEqualTo("john_doe");
- Assertions.assertThat(HANDLE_NONE.consistencyLevel()).isNull();
+ Assertions.assertThat(HANDLE_NONE.consistencyLevel()).isEmpty();
}
/**
@@ -69,6 +69,6 @@ public void test_withTenant() {
public void test_withTenant_async() {
var handle = HANDLE_NONE_ASYNC.withTenant("john_doe");
Assertions.assertThat(handle.tenant()).isEqualTo("john_doe");
- Assertions.assertThat(HANDLE_NONE_ASYNC.consistencyLevel()).isNull();
+ Assertions.assertThat(HANDLE_NONE_ASYNC.consistencyLevel()).isEmpty();
}
}
diff --git a/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleTest.java b/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleTest.java
index 9cf1e99d9..50286503a 100644
--- a/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleTest.java
+++ b/src/test/java/io/weaviate/client6/v1/api/collections/CollectionHandleTest.java
@@ -131,18 +131,18 @@ public void test_collectionHandleDefaults_rest(String __,
rest.assertNext((method, requestUrl, body, query) -> {
switch (clLoc) {
case QUERY:
- Assertions.assertThat(query).containsEntry("consistency_level", defaults.consistencyLevel());
+ Assertions.assertThat(query).containsEntry("consistency_level", defaults.consistencyLevel().get());
break;
case BODY:
- assertJsonHasValue(body, "consistency_level", defaults.consistencyLevel());
+ assertJsonHasValue(body, "consistency_level", defaults.consistencyLevel().get());
}
switch (tenantLoc) {
case QUERY:
- Assertions.assertThat(query).containsEntry("tenant", defaults.tenant());
+ Assertions.assertThat(query).containsEntry("tenant", defaults.tenant().get());
break;
case BODY:
- assertJsonHasValue(body, "tenant", defaults.tenant());
+ assertJsonHasValue(body, "tenant", defaults.tenant().get());
}
});
}
@@ -219,7 +219,7 @@ public void test_defaultTenant_getShards() throws IOException {
// Assert
rest.assertNext((method, requestUrl, body, query) -> {
- Assertions.assertThat(query).containsEntry("tenant", defaults.tenant());
+ Assertions.assertThat(query).containsEntry("tenant", defaults.tenant().get());
});
}
diff --git a/src/test/java/io/weaviate/testutil/transport/MockGrpcTransport.java b/src/test/java/io/weaviate/testutil/transport/MockGrpcTransport.java
index ebea2fea7..98af3d227 100644
--- a/src/test/java/io/weaviate/testutil/transport/MockGrpcTransport.java
+++ b/src/test/java/io/weaviate/testutil/transport/MockGrpcTransport.java
@@ -3,16 +3,21 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
+import java.util.OptionalInt;
import java.util.concurrent.CompletableFuture;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.MessageOrBuilder;
import com.google.protobuf.util.JsonFormat;
+import io.grpc.stub.StreamObserver;
import io.weaviate.client6.v1.internal.grpc.GrpcTransport;
import io.weaviate.client6.v1.internal.grpc.Rpc;
+import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamReply;
+import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamRequest;
public class MockGrpcTransport implements GrpcTransport {
+ private final String host = "example.com";
@FunctionalInterface
public interface AssertFunction {
@@ -57,4 +62,21 @@ public CompletableFuture perf
@Override
public void close() throws IOException {
}
+
+ @Override
+ public StreamObserver createStream(StreamObserver recv) {
+ // TODO(dyma): implement for tests
+ throw new UnsupportedOperationException("Unimplemented method 'createStream'");
+ }
+
+ @Override
+ public OptionalInt maxMessageSizeBytes() {
+ // TODO(dyma): implement for tests
+ throw new UnsupportedOperationException("Unimplemented method 'maxMessageSizeBytes'");
+ }
+
+ @Override
+ public String host() {
+ return host;
+ }
}