From fd50c7c636e2c8456d540ad54e4899342143ef9a Mon Sep 17 00:00:00 2001
From: Vova Kolmakov
Date: Wed, 22 Apr 2026 13:42:31 +0700
Subject: [PATCH 1/3] build: add spotless, checkstyle, and enforcer plugin
 configuration

---
 .editorconfig  |  19 ++++
 checkstyle.xml | 174 +++++++++++++++++++++++++++++++++++++++++++++++++
 pom.xml        |  90 +++++++++++++++++++++++++
 3 files changed, 283 insertions(+)
 create mode 100644 .editorconfig
 create mode 100644 checkstyle.xml

diff --git a/.editorconfig b/.editorconfig
new file mode 100644
index 0000000..0a2a8e1
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,19 @@
+root = true
+
+[*]
+charset = utf-8
+end_of_line = lf
+insert_final_newline = true
+trim_trailing_whitespace = true
+
+[*.java]
+indent_style = space
+indent_size = 4
+max_line_length = 100
+
+[*.{xml,yml,yaml}]
+indent_style = space
+indent_size = 2
+
+[Makefile]
+indent_style = tab
diff --git a/checkstyle.xml b/checkstyle.xml
new file mode 100644
index 0000000..a689896
--- /dev/null
+++ b/checkstyle.xml
@@ -0,0 +1,174 @@
+<!-- [The 174 added lines of Checkstyle module configuration were stripped as tags
+     during extraction and could not be recovered.] -->
diff --git a/pom.xml b/pom.xml
index d78486c..23f5df6 100644
--- a/pom.xml
+++ b/pom.xml
@@ -27,6 +27,28 @@
         2.20.0
         5.3.1
         3.24.2
+
+        <spotless.version>2.43.0</spotless.version>
+        <spotless.java.googlejavaformat.version>1.19.2</spotless.java.googlejavaformat.version>
+        <spotless.delimiter>package</spotless.delimiter>
+        <spotless.license.header>/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */</spotless.license.header>
+        <maven-checkstyle-plugin.version>3.3.1</maven-checkstyle-plugin.version>
+        <maven-enforcer-plugin.version>3.4.1</maven-enforcer-plugin.version>
@@ -345,6 +367,74 @@
+            <!-- Spotless code formatting with google-java-format. The XML element
+                 names in this plugin block were stripped during extraction and are
+                 reconstructed from the plugin's standard configuration schema; all
+                 values are original. -->
+            <plugin>
+                <groupId>com.diffplug.spotless</groupId>
+                <artifactId>spotless-maven-plugin</artifactId>
+                <version>${spotless.version}</version>
+                <configuration>
+                    <!-- [a true-valued flag whose element name was lost in extraction] -->
+                    <java>
+                        <includes>
+                            <include>src/main/java/**/*.java</include>
+                            <include>src/test/java/**/*.java</include>
+                        </includes>
+                        <googleJavaFormat>
+                            <version>${spotless.java.googlejavaformat.version}</version>
+                        </googleJavaFormat>
+                        <importOrder>
+                            <order>org.apache.flink,,javax,java,\#</order>
+                        </importOrder>
+                        <licenseHeader>
+                            <content>${spotless.license.header}</content>
+                            <delimiter>${spotless.delimiter}</delimiter>
+                        </licenseHeader>
+                    </java>
+                </configuration>
+            </plugin>
+
+            <!-- Checkstyle static analysis (flag element names reconstructed as above) -->
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-checkstyle-plugin</artifactId>
+                <version>${maven-checkstyle-plugin.version}</version>
+                <configuration>
+                    <configLocation>checkstyle.xml</configLocation>
+                    <consoleOutput>true</consoleOutput>
+                    <includeTestSourceDirectory>true</includeTestSourceDirectory>
+                    <violationSeverity>warning</violationSeverity>
+                    <failOnViolation>false</failOnViolation>
+                </configuration>
+            </plugin>
+
+            <!-- Enforce a minimum Maven version at validate time -->
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-enforcer-plugin</artifactId>
+                <version>${maven-enforcer-plugin.version}</version>
+                <executions>
+                    <execution>
+                        <id>enforce-versions</id>
+                        <phase>validate</phase>
+                        <goals>
+                            <goal>enforce</goal>
+                        </goals>
+                        <configuration>
+                            <rules>
+                                <requireMavenVersion>
+                                    <version>3.6.3</version>
+                                </requireMavenVersion>
+                            </rules>
+                        </configuration>
+                    </execution>
+                </executions>
+            </plugin>
From df2272caaac25bf387d2484b8f04ade923a90b2c Mon Sep 17 00:00:00 2001
From: Vova Kolmakov
Date: Wed, 22 Apr 2026 14:00:16 +0700
Subject: [PATCH 2/3] chore: apply initial spotless formatting

---
 .editorconfig                                      |    2 +-
 pom.xml                                            |   22 +-
 .../connector/lance/LanceAggregateSource.java      |  428 +++--
 .../connector/lance/LanceIndexBuilder.java         |  736 ++++----
 .../connector/lance/LanceInputFormat.java          |  506 +++---
 .../flink/connector/lance/LanceSink.java           |  489 +++---
 .../flink/connector/lance/LanceSource.java         |  616 ++++---
 .../flink/connector/lance/LanceSplit.java          |  180 +-
 .../connector/lance/LanceVectorSearch.java         |  677 ++++----
 .../lance/aggregate/AggregateExecutor.java         |  876 +++++-----
 .../lance/aggregate/AggregateInfo.java             |  388 ++---
 .../connector/lance/config/LanceOptions.java       | 1519 ++++++++--------
 .../lance/converter/LanceTypeConverter.java        |  712 ++++----
 .../lance/converter/RowDataConverter.java          | 1168 ++++++-------
 .../connector/lance/table/LanceCatalog.java        | 1489 ++++++++--------
 .../lance/table/LanceCatalogFactory.java           |  250 ++-
 .../lance/table/LanceDynamicTableFactory.java      |  379 ++--
 .../lance/table/LanceDynamicTableSink.java         |   92 +-
 .../lance/table/LanceDynamicTableSource.java       |  825 +++++----
 .../table/LanceVectorSearchFunction.java           |  516 +++---
 .../connector/lance/LanceConnectorITCase.java      |  712 ++++----
 .../lance/LanceIndexBuilderTest.java               |  517 +++---
 .../flink/connector/lance/LanceSinkTest.java       |  335 ++--
 .../connector/lance/LanceSourceTest.java           |  288 ++--
 .../lance/LanceTypeConverterTest.java              |  542 +++---
 .../lance/LanceVectorSearchTest.java               |  435 ++---
 .../aggregate/AggregateExecutorTest.java           |  892 +++++-----
 .../lance/aggregate/AggregateInfoTest.java         |  633 ++++---
 .../connector/lance/table/FlinkSqlDemo.java        | 1531 +++++++++--------
 .../table/LanceAggregatePushDownTest.java          |  591 +++----
 .../lance/table/LanceCatalogS3Test.java            | 1057 ++++++------
 .../table/LanceReadOptimizationsTest.java          |  856 +++++----
 .../connector/lance/table/LanceSqlITCase.java      |  590 +++----
 33 files changed, 10254 insertions(+), 10595 deletions(-)

diff --git a/.editorconfig b/.editorconfig
index 0a2a8e1..884396c 100644
--- a/.editorconfig
+++ b/.editorconfig
@@ -8,7 +8,7 @@ trim_trailing_whitespace = true
 
 [*.java]
 indent_style = space
-indent_size = 4
+indent_size = 2
 max_line_length = 100
 
 [*.{xml,yml,yaml}]
diff --git a/pom.xml b/pom.xml
index 23f5df6..95ada5f 100644
--- a/pom.xml
+++ b/pom.xml
@@ -368,7 +368,7 @@
 
-            <!-- [one-line comment stripped in extraction] -->
+            <!-- [one-line comment stripped in extraction] -->
             <plugin>
                 <groupId>com.diffplug.spotless</groupId>
                 <artifactId>spotless-maven-plugin</artifactId>
                 <version>${spotless.version}</version>
@@ -396,9 +396,18 @@
                             <delimiter>${spotless.delimiter}</delimiter>
                         </licenseHeader>
                 </configuration>
+                <executions>
+                    <execution>
+                        <id>spotless-check</id>
+                        <phase>validate</phase>
+                        <goals>
+                            <goal>check</goal>
+                        </goals>
+                    </execution>
+                </executions>
             </plugin>
 
-            <!-- [one-line comment stripped in extraction] -->
+            <!-- [one-line comment stripped in extraction] -->
             <plugin>
                 <groupId>org.apache.maven.plugins</groupId>
                 <artifactId>maven-checkstyle-plugin</artifactId>
@@ -410,6 +419,15 @@
                     <violationSeverity>warning</violationSeverity>
                     <failOnViolation>false</failOnViolation>
                 </configuration>
+                <executions>
+                    <execution>
+                        <id>checkstyle-check</id>
+                        <phase>validate</phase>
+                        <goals>
+                            <goal>check</goal>
+                        </goals>
+                    </execution>
+                </executions>
             </plugin>
diff --git a/src/main/java/org/apache/flink/connector/lance/LanceAggregateSource.java b/src/main/java/org/apache/flink/connector/lance/LanceAggregateSource.java
index ed512e5..6946ed3 100644
--- a/src/main/java/org/apache/flink/connector/lance/LanceAggregateSource.java
+++ b/src/main/java/org/apache/flink/connector/lance/LanceAggregateSource.java
@@ -1,11 +1,7 @@
 /*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
@@ -15,7 +11,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.flink.connector.lance;
 
 import org.apache.flink.configuration.Configuration;
@@ -41,16 +36,17 @@ import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
-import java.nio.file.Paths;
 import java.util.Arrays;
 import java.util.List;
 
 /**
  * Lance data source with aggregate push-down support.
  *
- * <p>Executes aggregate computation at the data source, supporting COUNT, SUM, AVG, MIN, MAX and other aggregate functions.
+ * <p>Executes aggregate computation at the data source, supporting COUNT, SUM, AVG, MIN, MAX and
+ * other aggregate functions.
  *
  * <p>Usage example:
+ *
  * <pre>{@code
  * AggregateInfo aggInfo = AggregateInfo.builder()
  *     .addCountStar("cnt")
@@ -63,263 +59,249 @@
  */
 public class LanceAggregateSource extends RichParallelSourceFunction<RowData> {
 
-    private static final long serialVersionUID = 1L;
-    private static final Logger LOG = LoggerFactory.getLogger(LanceAggregateSource.class);
-
-    private final LanceOptions options;
-    private final RowType sourceRowType;
-    private final AggregateInfo aggregateInfo;
-    private final String[] selectedColumns;
-
-    private transient volatile boolean running;
-    private transient BufferAllocator allocator;
-    private transient Dataset dataset;
-    private transient RowDataConverter converter;
-    private transient AggregateExecutor aggregateExecutor;
-
-    /**
-     * Create LanceAggregateSource
-     *
-     * @param options Lance configuration options
-     * @param sourceRowType RowType of source table
-     * @param aggregateInfo Aggregate information
-     */
-    public LanceAggregateSource(LanceOptions options, RowType sourceRowType, AggregateInfo aggregateInfo) {
-        this.options = options;
-        this.sourceRowType = sourceRowType;
-        this.aggregateInfo = aggregateInfo;
-
-        // Calculate columns to read
-        List<String> requiredColumns = aggregateInfo.getRequiredColumns();
-        this.selectedColumns = requiredColumns.isEmpty() ? null : requiredColumns.toArray(new String[0]);
+  private static final long serialVersionUID = 1L;
+  private static final Logger LOG = LoggerFactory.getLogger(LanceAggregateSource.class);
+
+  private final LanceOptions options;
+  private final RowType sourceRowType;
+  private final AggregateInfo aggregateInfo;
+  private final String[] selectedColumns;
+
+  private transient volatile boolean running;
+  private transient BufferAllocator allocator;
+  private transient Dataset dataset;
+  private transient RowDataConverter converter;
+  private transient AggregateExecutor aggregateExecutor;
+
+  /**
+   * Create LanceAggregateSource
+   *
+   * @param options Lance configuration options
+   * @param sourceRowType RowType of source table
+   * @param aggregateInfo Aggregate information
+   */
+  public LanceAggregateSource(
+      LanceOptions options, RowType sourceRowType, AggregateInfo aggregateInfo) {
+    this.options = options;
+    this.sourceRowType = sourceRowType;
+    this.aggregateInfo = aggregateInfo;
+
+    // Calculate columns to read
+    List<String> requiredColumns = aggregateInfo.getRequiredColumns();
+    this.selectedColumns =
+        requiredColumns.isEmpty() ? null : requiredColumns.toArray(new String[0]);
+  }
+
+  @Override
+  public void open(Configuration parameters) throws Exception {
+    super.open(parameters);
+
+    LOG.info("Opening Lance aggregate data source: {}", options.getPath());
+    LOG.info("Aggregate info: {}", aggregateInfo);
+
+    this.running = true;
+    this.allocator = new RootAllocator(Long.MAX_VALUE);
+
+    // Open Lance dataset
+    String datasetPath = options.getPath();
+    if (datasetPath == null || datasetPath.isEmpty()) {
+      throw new IllegalArgumentException("Lance dataset path cannot be empty");
     }
 
-    @Override
-    public void open(Configuration parameters) throws Exception {
-        super.open(parameters);
+    try {
+      this.dataset = Dataset.open(datasetPath, allocator);
+    } catch (Exception e) {
+      throw new IOException("Failed to open Lance dataset: " + datasetPath, e);
+    }
 
-        LOG.info("Opening Lance aggregate data source: {}", options.getPath());
-        LOG.info("Aggregate info: {}", aggregateInfo);
+    // Initialize RowDataConverter (using source table Schema)
+    RowType actualRowType = this.sourceRowType;
+    if (actualRowType == null) {
+      Schema arrowSchema = dataset.getSchema();
+      actualRowType = LanceTypeConverter.toFlinkRowType(arrowSchema);
+    }
+    this.converter = new RowDataConverter(actualRowType);
 
-        this.running = true;
-        this.allocator = new RootAllocator(Long.MAX_VALUE);
+    // Initialize aggregate executor
+    this.aggregateExecutor = new AggregateExecutor(aggregateInfo, actualRowType);
+    this.aggregateExecutor.init();
 
-        // Open Lance dataset
-        String datasetPath = options.getPath();
-        if (datasetPath == null || datasetPath.isEmpty()) {
-            throw new IllegalArgumentException("Lance dataset path cannot be empty");
-        }
+    LOG.info("Lance aggregate data source opened");
+  }
 
-        try {
-            this.dataset = Dataset.open(datasetPath, allocator);
-        } catch (Exception e) {
-            throw new IOException("Failed to open Lance dataset: " + datasetPath, e);
-        }
-
-        // Initialize RowDataConverter (using source table Schema)
-        RowType actualRowType = this.sourceRowType;
-        if (actualRowType == null) {
-            Schema arrowSchema = dataset.getSchema();
-            actualRowType = LanceTypeConverter.toFlinkRowType(arrowSchema);
-        }
-        this.converter = new RowDataConverter(actualRowType);
+  @Override
+  public void run(SourceContext<RowData> ctx) throws Exception {
+    LOG.info("Starting aggregate read from Lance dataset: {}", options.getPath());
 
-        // Initialize aggregate executor
-        this.aggregateExecutor = new AggregateExecutor(aggregateInfo, actualRowType);
-        this.aggregateExecutor.init();
+    int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
+    int numSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
 
-        LOG.info("Lance aggregate data source opened");
+    // Aggregate operation only executes on subtask 0 to avoid duplicate aggregation
+    if (subtaskIndex != 0) {
+      LOG.info("Subtask {} skipped (only subtask 0 executes in aggregate mode)", subtaskIndex);
+      return;
     }
 
-    @Override
-    public void run(SourceContext<RowData> ctx) throws Exception {
-        LOG.info("Starting aggregate read from Lance dataset: {}", options.getPath());
-
-        int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
-        int numSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
+    String filter = options.getReadFilter();
 
-        // Aggregate operation only executes on subtask 0 to avoid duplicate aggregation
-        if (subtaskIndex != 0) {
-            LOG.info("Subtask {} skipped (only subtask 0 executes in aggregate mode)", subtaskIndex);
-            return;
-        }
+    // Read all data and perform aggregation
+    if (filter != null && !filter.isEmpty()) {
+      readAndAggregateWithFilter(ctx);
+    } else {
+      readAndAggregateAll(ctx);
+    }
 
-        String filter = options.getReadFilter();
+    LOG.info("Lance aggregate data source read completed");
+  }
 
-        // Read all data and perform aggregation
-        if (filter != null && !filter.isEmpty()) {
-            readAndAggregateWithFilter(ctx);
-        } else {
-            readAndAggregateAll(ctx);
-        }
+  /** Aggregate read with filter condition */
+  private void readAndAggregateWithFilter(SourceContext<RowData> ctx) throws Exception {
+    ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder();
+    scanOptionsBuilder.batchSize(options.getReadBatchSize());
 
-        LOG.info("Lance aggregate data source read completed");
+    if (selectedColumns != null && selectedColumns.length > 0) {
+      scanOptionsBuilder.columns(Arrays.asList(selectedColumns));
     }
 
-    /**
-     * Aggregate read with filter condition
-     */
-    private void readAndAggregateWithFilter(SourceContext<RowData> ctx) throws Exception {
-        ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder();
-        scanOptionsBuilder.batchSize(options.getReadBatchSize());
-
-        if (selectedColumns != null && selectedColumns.length > 0) {
-            scanOptionsBuilder.columns(Arrays.asList(selectedColumns));
-        }
-
-        String filter = options.getReadFilter();
-        if (filter != null && !filter.isEmpty()) {
-            LOG.info("Applying filter condition: {}", filter);
-            scanOptionsBuilder.filter(filter);
-        }
+    String filter = options.getReadFilter();
+    if (filter != null && !filter.isEmpty()) {
+      LOG.info("Applying filter condition: {}", filter);
+      scanOptionsBuilder.filter(filter);
+    }
 
-        ScanOptions scanOptions = scanOptionsBuilder.build();
+    ScanOptions scanOptions = scanOptionsBuilder.build();
 
-        // Phase 1: Read data and accumulate aggregation
-        try (LanceScanner scanner = dataset.newScan(scanOptions)) {
-            try (ArrowReader reader = scanner.scanBatches()) {
-                while (reader.loadNextBatch() && running) {
-                    VectorSchemaRoot root = reader.getVectorSchemaRoot();
-                    List<RowData> rows = converter.toRowDataList(root);
+    // Phase 1: Read data and accumulate aggregation
+    try (LanceScanner scanner = dataset.newScan(scanOptions)) {
+      try (ArrowReader reader = scanner.scanBatches()) {
+        while (reader.loadNextBatch() && running) {
+          VectorSchemaRoot root = reader.getVectorSchemaRoot();
+          List<RowData> rows = converter.toRowDataList(root);
 
-                    for (RowData row : rows) {
-                        aggregateExecutor.accumulate(row);
-                    }
-                }
-            }
+          for (RowData row : rows) {
+            aggregateExecutor.accumulate(row);
+          }
         }
-
-        // Phase 2: Output aggregate results
-        outputAggregateResults(ctx);
+      }
     }
 
-    /**
-     * Read all data and aggregate (without filter condition)
-     */
-    private void readAndAggregateAll(SourceContext<RowData> ctx) throws Exception {
-        List<Fragment> fragments = dataset.getFragments();
-        LOG.info("Dataset has {} Fragments", fragments.size());
-
-        // Phase 1: Read all Fragments and accumulate aggregation
-        for (Fragment fragment : fragments) {
-            if (!running) {
-                break;
-            }
-            readAndAggregateFragment(fragment);
-        }
-
-        // Phase 2: Output aggregate results
-        outputAggregateResults(ctx);
+    // Phase 2: Output aggregate results
+    outputAggregateResults(ctx);
+  }
+
+  /** Read all data and aggregate (without filter condition) */
+  private void readAndAggregateAll(SourceContext<RowData> ctx) throws Exception {
+    List<Fragment> fragments = dataset.getFragments();
+    LOG.info("Dataset has {} Fragments", fragments.size());
+
+    // Phase 1: Read all Fragments and accumulate aggregation
+    for (Fragment fragment : fragments) {
+      if (!running) {
+        break;
+      }
+      readAndAggregateFragment(fragment);
     }
 
-    /**
-     * Read single Fragment and accumulate aggregation
-     */
-    private void readAndAggregateFragment(Fragment fragment) throws Exception {
-        LOG.debug("Reading Fragment: {}", fragment.getId());
+    // Phase 2: Output aggregate results
+    outputAggregateResults(ctx);
+  }
 
-        ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder();
-        scanOptionsBuilder.batchSize(options.getReadBatchSize());
+  /** Read single Fragment and accumulate aggregation */
+  private void readAndAggregateFragment(Fragment fragment) throws Exception {
+    LOG.debug("Reading Fragment: {}", fragment.getId());
 
-        if (selectedColumns != null && selectedColumns.length > 0) {
-            scanOptionsBuilder.columns(Arrays.asList(selectedColumns));
-        }
+    ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder();
+    scanOptionsBuilder.batchSize(options.getReadBatchSize());
 
-        ScanOptions scanOptions = scanOptionsBuilder.build();
+    if (selectedColumns != null && selectedColumns.length > 0) {
+      scanOptionsBuilder.columns(Arrays.asList(selectedColumns));
+    }
 
-        try (LanceScanner scanner = fragment.newScan(scanOptions)) {
-            try (ArrowReader reader = scanner.scanBatches()) {
-                while (reader.loadNextBatch() && running) {
-                    VectorSchemaRoot root = reader.getVectorSchemaRoot();
-                    List<RowData> rows = converter.toRowDataList(root);
+    ScanOptions scanOptions = scanOptionsBuilder.build();
 
-                    for (RowData row : rows) {
-                        aggregateExecutor.accumulate(row);
-                    }
-                }
-            }
-        }
-    }
+    try (LanceScanner scanner = fragment.newScan(scanOptions)) {
+      try (ArrowReader reader = scanner.scanBatches()) {
+        while (reader.loadNextBatch() && running) {
+          VectorSchemaRoot root = reader.getVectorSchemaRoot();
+          List<RowData> rows = converter.toRowDataList(root);
 
-    /**
-     * Output aggregate results
-     */
-    private void outputAggregateResults(SourceContext<RowData> ctx) {
-        List<RowData> results = aggregateExecutor.getResults();
-        LOG.info("Aggregation completed, {} result rows", results.size());
-
-        synchronized (ctx.getCheckpointLock()) {
-            for (RowData result : results) {
-                ctx.collect(result);
-            }
+          for (RowData row : rows) {
+            aggregateExecutor.accumulate(row);
+          }
         }
+      }
     }
+  }
+
+  /** Output aggregate results */
+  private void outputAggregateResults(SourceContext<RowData> ctx) {
+    List<RowData> results = aggregateExecutor.getResults();
+    LOG.info("Aggregation completed, {} result rows", results.size());
 
-    @Override
-    public void cancel() {
-        LOG.info("Cancelling Lance aggregate data source");
-        this.running = false;
+    synchronized (ctx.getCheckpointLock()) {
+      for (RowData result : results) {
+        ctx.collect(result);
+      }
     }
+  }
 
-    @Override
-    public void close() throws Exception {
-        LOG.info("Closing Lance aggregate data source");
+  @Override
+  public void cancel() {
+    LOG.info("Cancelling Lance aggregate data source");
+    this.running = false;
+  }
 
-        this.running = false;
+  @Override
+  public void close() throws Exception {
+    LOG.info("Closing Lance aggregate data source");
 
-        if (aggregateExecutor != null) {
-            aggregateExecutor.reset();
-        }
+    this.running = false;
 
-        if (dataset != null) {
-            try {
-                dataset.close();
-            } catch (Exception e) {
-                LOG.warn("Error closing Lance dataset", e);
-            }
-            dataset = null;
-        }
-
-        if (allocator != null) {
-            try {
-                allocator.close();
-            } catch (Exception e) {
-                LOG.warn("Error closing memory allocator", e);
-            }
-            allocator = null;
-        }
-
-        super.close();
+    if (aggregateExecutor != null) {
+      aggregateExecutor.reset();
     }
 
-    /**
-     * Get aggregate information
-     */
-    public AggregateInfo getAggregateInfo() {
-        return aggregateInfo;
+    if (dataset != null) {
+      try {
+        dataset.close();
+      } catch (Exception e) {
+        LOG.warn("Error closing Lance dataset", e);
+      }
+      dataset = null;
     }
 
-    /**
-     * Get configuration options
-     */
-    public LanceOptions getOptions() {
-        return options;
+    if (allocator != null) {
+      try {
+        allocator.close();
+      } catch (Exception e) {
+        LOG.warn("Error closing memory allocator", e);
+      }
+      allocator = null;
     }
 
-    /**
-     * Get source table RowType
-     */
-    public RowType getSourceRowType() {
-        return sourceRowType;
-    }
+    super.close();
+  }
 
-    /**
-     * Get aggregate result RowType
-     */
-    public RowType getResultRowType() {
-        if (aggregateExecutor != null) {
-            return aggregateExecutor.buildResultRowType();
-        }
-        return null;
+  /** Get aggregate information */
+  public AggregateInfo getAggregateInfo() {
+    return aggregateInfo;
+  }
+
+  /** Get configuration options */
+  public LanceOptions getOptions() {
+    return options;
+  }
+
+  /** Get source table RowType */
+  public RowType getSourceRowType() {
+    return sourceRowType;
+  }
+
+  /** Get aggregate result RowType */
+  public RowType getResultRowType() {
+    if (aggregateExecutor != null) {
+      return aggregateExecutor.buildResultRowType();
     }
+    return null;
+  }
 }
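[Reviewer sketch, not part of the patch: one way to wire LanceAggregateSource into a job.
It reuses only the constructor above and the builder calls shown in this connector's
Javadoc examples (LanceOptions.builder() appears in the LanceSink Javadoc below;
AggregateInfo.builder().addCountStar("cnt") in this class's Javadoc). Whether
AggregateInfo's builder finishes with build() is an assumption. Because only subtask 0
emits the aggregate, the source can run at any parallelism without duplicating results.]

    // Sketch only; builder method availability is assumed as noted above.
    LanceOptions options = LanceOptions.builder().path("/path/to/lance/dataset").build();
    AggregateInfo aggInfo = AggregateInfo.builder().addCountStar("cnt").build();

    StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    // Passing null for sourceRowType is allowed: open() falls back to the dataset's
    // Arrow schema via LanceTypeConverter.toFlinkRowType(...).
    env.addSource(new LanceAggregateSource(options, null, aggInfo)).print();
    env.execute("lance-aggregate-sketch");
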
diff --git a/src/main/java/org/apache/flink/connector/lance/LanceIndexBuilder.java b/src/main/java/org/apache/flink/connector/lance/LanceIndexBuilder.java
index 3f93897..59b780e 100644
--- a/src/main/java/org/apache/flink/connector/lance/LanceIndexBuilder.java
+++ b/src/main/java/org/apache/flink/connector/lance/LanceIndexBuilder.java
@@ -1,11 +1,7 @@
 /*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
@@ -15,7 +11,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.flink.connector.lance;
 
 import org.apache.flink.connector.lance.config.LanceOptions;
@@ -41,10 +36,11 @@
 
 /**
  * Lance vector index builder.
- * 
+ *
  * <p>Supports building IVF_PQ, IVF_HNSW_PQ, and IVF_FLAT vector indices.
- * 
+ *
  * <p>Usage example:
+ *
  * <pre>{@code
  * LanceIndexBuilder builder = LanceIndexBuilder.builder()
  *     .datasetPath("/path/to/dataset")
@@ -53,384 +49,390 @@
  *     .numPartitions(256)
  *     .numSubVectors(16)
  *     .build();
- * 
+ *
  * IndexBuildResult result = builder.buildIndex();
  * }</pre>
  */
 public class LanceIndexBuilder implements Closeable, Serializable {
-    private static final long serialVersionUID = 1L;
-    private static final Logger LOG = LoggerFactory.getLogger(LanceIndexBuilder.class);
+  private static final long serialVersionUID = 1L;
+  private static final Logger LOG = LoggerFactory.getLogger(LanceIndexBuilder.class);
+
+  private final String datasetPath;
+  private final String columnName;
+  private final LanceOptions.IndexType indexType;
+  private final LanceOptions.MetricType metricType;
+  private final int numPartitions;
+  private final Integer numSubVectors;
+  private final int numBits;
+  private final int maxLevel;
+  private final int m;
+  private final int efConstruction;
+  private final boolean replace;
+
+  private transient BufferAllocator allocator;
+  private transient Dataset dataset;
+
+  private LanceIndexBuilder(Builder builder) {
+    this.datasetPath = builder.datasetPath;
+    this.columnName = builder.columnName;
+    this.indexType = builder.indexType;
+    this.metricType = builder.metricType;
+    this.numPartitions = builder.numPartitions;
+    this.numSubVectors = builder.numSubVectors;
+    this.numBits = builder.numBits;
+    this.maxLevel = builder.maxLevel;
+    this.m = builder.m;
+    this.efConstruction = builder.efConstruction;
+    this.replace = builder.replace;
+  }
+
+  /**
+   * Build vector index
+   *
+   * @return Index build result
+   */
+  public IndexBuildResult buildIndex() throws IOException {
+    LOG.info(
+        "Starting to build vector index, type: {}, column: {}, dataset: {}",
+        indexType,
+        columnName,
+        datasetPath);
+
+    long startTime = System.currentTimeMillis();
+
+    try {
+      // Initialize resources
+      this.allocator = new RootAllocator(Long.MAX_VALUE);
+      this.dataset = Dataset.open(datasetPath, allocator);
+
+      // Validate column exists
+      validateColumn();
+
+      // Get distance metric type
+      DistanceType distanceType = toDistanceType(metricType);
+
+      // Build IVF parameters
+      IvfBuildParams ivfParams =
+          new IvfBuildParams.Builder().setNumPartitions(numPartitions).build();
+
+      // Build index based on index type
+      IndexType lanceIndexType;
+      IndexParams indexParams;
+
+      switch (indexType) {
+        case IVF_PQ:
+          lanceIndexType = IndexType.IVF_PQ;
+          PQBuildParams pqParams =
+              new PQBuildParams.Builder()
+                  .setNumSubVectors(numSubVectors != null ? numSubVectors : 16)
+                  .setNumBits(numBits)
+                  .build();
+          VectorIndexParams ivfPqParams =
+              VectorIndexParams.withIvfPqParams(distanceType, ivfParams, pqParams);
+          indexParams =
+              new IndexParams.Builder()
+                  .setDistanceType(distanceType)
+                  .setVectorIndexParams(ivfPqParams)
+                  .build();
+          break;
+
+        case IVF_HNSW:
+          lanceIndexType = IndexType.IVF_HNSW_PQ;
+          HnswBuildParams hnswParams =
+              new HnswBuildParams.Builder()
+                  .setMaxLevel((short) maxLevel)
+                  .setM(m)
+                  .setEfConstruction(efConstruction)
+                  .build();
+          PQBuildParams hnswPqParams =
+              new PQBuildParams.Builder()
+                  .setNumSubVectors(numSubVectors != null ? numSubVectors : 16)
+                  .setNumBits(numBits)
+                  .build();
+          VectorIndexParams ivfHnswParams =
+              VectorIndexParams.withIvfHnswPqParams(
+                  distanceType, ivfParams, hnswParams, hnswPqParams);
+          indexParams =
+              new IndexParams.Builder()
+                  .setDistanceType(distanceType)
+                  .setVectorIndexParams(ivfHnswParams)
+                  .build();
+          break;
+
+        case IVF_FLAT:
+          lanceIndexType = IndexType.IVF_FLAT;
+          VectorIndexParams ivfFlatParams = VectorIndexParams.ivfFlat(numPartitions, distanceType);
+          indexParams =
+              new IndexParams.Builder()
+                  .setDistanceType(distanceType)
+                  .setVectorIndexParams(ivfFlatParams)
+                  .build();
+          break;
+
+        default:
+          throw new IllegalArgumentException("Unsupported index type: " + indexType);
+      }
+
+      // Create index
+      dataset.createIndex(
+          Collections.singletonList(columnName),
+          lanceIndexType,
+          Optional.empty(), // Index name, use default
+          indexParams,
+          replace);
+
+      long endTime = System.currentTimeMillis();
+      long duration = endTime - startTime;
+
+      LOG.info("Vector index build completed, duration: {} ms", duration);
+
+      return new IndexBuildResult(true, indexType, columnName, datasetPath, duration, null);
+    } catch (Exception e) {
+      LOG.error("Failed to build vector index", e);
+      return new IndexBuildResult(
+          false,
+          indexType,
+          columnName,
+          datasetPath,
+          System.currentTimeMillis() - startTime,
+          e.getMessage());
+    }
+  }
 
-    private final String datasetPath;
-    private final String columnName;
-    private final LanceOptions.IndexType indexType;
-    private final LanceOptions.MetricType metricType;
-    private final int numPartitions;
-    private final Integer numSubVectors;
-    private final int numBits;
-    private final int maxLevel;
-    private final int m;
-    private final int efConstruction;
-    private final boolean replace;
-
-    private transient BufferAllocator allocator;
-    private transient Dataset dataset;
-
-    private LanceIndexBuilder(Builder builder) {
-        this.datasetPath = builder.datasetPath;
-        this.columnName = builder.columnName;
-        this.indexType = builder.indexType;
-        this.metricType = builder.metricType;
-        this.numPartitions = builder.numPartitions;
-        this.numSubVectors = builder.numSubVectors;
-        this.numBits = builder.numBits;
-        this.maxLevel = builder.maxLevel;
-        this.m = builder.m;
-        this.efConstruction = builder.efConstruction;
-        this.replace = builder.replace;
+  /** Validate vector column exists */
+  private void validateColumn() throws IOException {
+    // Check if column exists in Schema
+    boolean columnExists =
+        dataset.getSchema().getFields().stream()
+            .anyMatch(field -> field.getName().equals(columnName));
+
+    if (!columnExists) {
+      throw new IOException("Vector column does not exist: " + columnName);
+    }
+  }
+
+  /** Convert distance metric type */
+  private DistanceType toDistanceType(LanceOptions.MetricType metricType) {
+    switch (metricType) {
+      case L2:
+        return DistanceType.L2;
+      case COSINE:
+        return DistanceType.Cosine;
+      case DOT:
+        return DistanceType.Dot;
+      default:
+        return DistanceType.L2;
+    }
+  }
+
+  @Override
+  public void close() throws IOException {
+    if (dataset != null) {
+      try {
+        dataset.close();
+      } catch (Exception e) {
+        LOG.warn("Failed to close dataset", e);
+      }
+      dataset = null;
     }
 
-    /**
-     * Build vector index
-     *
-     * @return Index build result
-     */
-    public IndexBuildResult buildIndex() throws IOException {
-        LOG.info("Starting to build vector index, type: {}, column: {}, dataset: {}",
-                indexType, columnName, datasetPath);
-
-        long startTime = System.currentTimeMillis();
-
-        try {
-            // Initialize resources
-            this.allocator = new RootAllocator(Long.MAX_VALUE);
-            this.dataset = Dataset.open(datasetPath, allocator);
-
-            // Validate column exists
-            validateColumn();
-
-            // Get distance metric type
-            DistanceType distanceType = toDistanceType(metricType);
-
-            // Build IVF parameters
-            IvfBuildParams ivfParams = new IvfBuildParams.Builder()
-                    .setNumPartitions(numPartitions)
-                    .build();
-
-            // Build index based on index type
-            IndexType lanceIndexType;
-            IndexParams indexParams;
-
-            switch (indexType) {
-                case IVF_PQ:
-                    lanceIndexType = IndexType.IVF_PQ;
-                    PQBuildParams pqParams = new PQBuildParams.Builder()
-                            .setNumSubVectors(numSubVectors != null ? numSubVectors : 16)
-                            .setNumBits(numBits)
-                            .build();
-                    VectorIndexParams ivfPqParams = VectorIndexParams.withIvfPqParams(
-                            distanceType, ivfParams, pqParams);
-                    indexParams = new IndexParams.Builder()
-                            .setDistanceType(distanceType)
-                            .setVectorIndexParams(ivfPqParams)
-                            .build();
-                    break;
-
-                case IVF_HNSW:
-                    lanceIndexType = IndexType.IVF_HNSW_PQ;
-                    HnswBuildParams hnswParams = new HnswBuildParams.Builder()
-                            .setMaxLevel((short) maxLevel)
-                            .setM(m)
-                            .setEfConstruction(efConstruction)
-                            .build();
-                    PQBuildParams hnswPqParams = new PQBuildParams.Builder()
-                            .setNumSubVectors(numSubVectors != null ? numSubVectors : 16)
-                            .setNumBits(numBits)
-                            .build();
-                    VectorIndexParams ivfHnswParams = VectorIndexParams.withIvfHnswPqParams(
-                            distanceType, ivfParams, hnswParams, hnswPqParams);
-                    indexParams = new IndexParams.Builder()
-                            .setDistanceType(distanceType)
-                            .setVectorIndexParams(ivfHnswParams)
-                            .build();
-                    break;
-
-                case IVF_FLAT:
-                    lanceIndexType = IndexType.IVF_FLAT;
-                    VectorIndexParams ivfFlatParams = VectorIndexParams.ivfFlat(numPartitions, distanceType);
-                    indexParams = new IndexParams.Builder()
-                            .setDistanceType(distanceType)
-                            .setVectorIndexParams(ivfFlatParams)
-                            .build();
-                    break;
-
-                default:
-                    throw new IllegalArgumentException("Unsupported index type: " + indexType);
-            }
-
-            // Create index
-            dataset.createIndex(
-                Collections.singletonList(columnName),
-                lanceIndexType,
-                Optional.empty(), // Index name, use default
-                indexParams,
-                replace
-            );
-
-            long endTime = System.currentTimeMillis();
-            long duration = endTime - startTime;
-
-            LOG.info("Vector index build completed, duration: {} ms", duration);
-
-            return new IndexBuildResult(
-                true,
-                indexType,
-                columnName,
-                datasetPath,
-                duration,
-                null
-            );
-        } catch (Exception e) {
-            LOG.error("Failed to build vector index", e);
-            return new IndexBuildResult(
-                false,
-                indexType,
-                columnName,
-                datasetPath,
-                System.currentTimeMillis() - startTime,
-                e.getMessage()
-            );
-        }
+    if (allocator != null) {
+      try {
+        allocator.close();
+      } catch (Exception e) {
+        LOG.warn("Failed to close allocator", e);
+      }
+      allocator = null;
+    }
+  }
+
+  /** Create builder */
+  public static Builder builder() {
+    return new Builder();
+  }
+
+  /** Create index builder from LanceOptions */
+  public static LanceIndexBuilder fromOptions(LanceOptions options) {
+    return builder()
+        .datasetPath(options.getPath())
+        .columnName(options.getIndexColumn())
+        .indexType(options.getIndexType())
+        .metricType(options.getVectorMetric())
+        .numPartitions(options.getIndexNumPartitions())
+        .numSubVectors(options.getIndexNumSubVectors())
+        .numBits(options.getIndexNumBits())
+        .maxLevel(options.getIndexMaxLevel())
+        .m(options.getIndexM())
+        .efConstruction(options.getIndexEfConstruction())
+        .build();
+  }
 
-    /**
-     * Validate vector column exists
-     */
-    private void validateColumn() throws IOException {
-        // Check if column exists in Schema
-        boolean columnExists = dataset.getSchema().getFields().stream()
-                .anyMatch(field -> field.getName().equals(columnName));
-
-        if (!columnExists) {
-            throw new IOException("Vector column does not exist: " + columnName);
-        }
-    }
-
-    /**
-     * Convert distance metric type
-     */
-    private DistanceType toDistanceType(LanceOptions.MetricType metricType) {
-        switch (metricType) {
-            case L2:
-                return DistanceType.L2;
-            case COSINE:
-                return DistanceType.Cosine;
-            case DOT:
-                return DistanceType.Dot;
-            default:
-                return DistanceType.L2;
-        }
-    }
-
-    @Override
-    public void close() throws IOException {
-        if (dataset != null) {
-            try {
-                dataset.close();
-            } catch (Exception e) {
-                LOG.warn("Failed to close dataset", e);
-            }
-            dataset = null;
-        }
-
-        if (allocator != null) {
-            try {
-                allocator.close();
-            } catch (Exception e) {
-                LOG.warn("Failed to close allocator", e);
-            }
-            allocator = null;
-        }
-    }
-
-    /**
-     * Create builder
-     */
-    public static Builder builder() {
-        return new Builder();
-    }
-
-    /**
-     * Create index builder from LanceOptions
-     */
-    public static LanceIndexBuilder fromOptions(LanceOptions options) {
-        return builder()
-            .datasetPath(options.getPath())
-            .columnName(options.getIndexColumn())
-            .indexType(options.getIndexType())
-            .metricType(options.getVectorMetric())
-            .numPartitions(options.getIndexNumPartitions())
-            .numSubVectors(options.getIndexNumSubVectors())
-            .numBits(options.getIndexNumBits())
-            .maxLevel(options.getIndexMaxLevel())
-            .m(options.getIndexM())
-            .efConstruction(options.getIndexEfConstruction())
-            .build();
-    }
-
-    /**
-     * Builder
-     */
-    public static class Builder {
-        private String datasetPath;
-        private String columnName;
-        private LanceOptions.IndexType indexType = LanceOptions.IndexType.IVF_PQ;
-        private LanceOptions.MetricType metricType = LanceOptions.MetricType.L2;
-        private int numPartitions = 256;
-        private Integer numSubVectors;
-        private int numBits = 8;
-        private int maxLevel = 7;
-        private int m = 16;
-        private int efConstruction = 100;
-        private boolean replace = false;
-
-        public Builder datasetPath(String datasetPath) {
-            this.datasetPath = datasetPath;
-            return this;
-        }
-
-        public Builder columnName(String columnName) {
-            this.columnName = columnName;
-            return this;
-        }
-
-        public Builder indexType(LanceOptions.IndexType indexType) {
-            this.indexType = indexType;
-            return this;
-        }
-
-        public Builder metricType(LanceOptions.MetricType metricType) {
-            this.metricType = metricType;
-            return this;
-        }
-
-        public Builder numPartitions(int numPartitions) {
-            this.numPartitions = numPartitions;
-            return this;
-        }
-
-        public Builder numSubVectors(Integer numSubVectors) {
-            this.numSubVectors = numSubVectors;
-            return this;
-        }
-
-        public Builder numBits(int numBits) {
-            this.numBits = numBits;
-            return this;
-        }
-
-        public Builder maxLevel(int maxLevel) {
-            this.maxLevel = maxLevel;
-            return this;
-        }
-
-        public Builder m(int m) {
-            this.m = m;
-            return this;
-        }
-
-        public Builder efConstruction(int efConstruction) {
-            this.efConstruction = efConstruction;
-            return this;
-        }
-
-        public Builder replace(boolean replace) {
-            this.replace = replace;
-            return this;
-        }
-
-        public LanceIndexBuilder build() {
-            validate();
-            return new LanceIndexBuilder(this);
-        }
-
-        private void validate() {
-            if (datasetPath == null || datasetPath.isEmpty()) {
-                throw new IllegalArgumentException("Dataset path cannot be empty");
-            }
-            if (columnName == null || columnName.isEmpty()) {
-                throw new IllegalArgumentException("Column name cannot be empty");
-            }
-            if (numPartitions <= 0) {
-                throw new IllegalArgumentException("Number of partitions must be greater than 0");
-            }
-            if (numSubVectors != null && numSubVectors <= 0) {
-                throw new IllegalArgumentException("Number of sub-vectors must be greater than 0");
-            }
-            if (numBits <= 0 || numBits > 16) {
-                throw new IllegalArgumentException("Quantization bits must be between 1 and 16");
-            }
-        }
-    }
-
-    /**
-     * Index build result
-     */
-    public static class IndexBuildResult implements Serializable {
-        private static final long serialVersionUID = 1L;
-
-        private final boolean success;
-        private final LanceOptions.IndexType indexType;
-        private final String columnName;
-        private final String datasetPath;
-        private final long durationMillis;
-        private final String errorMessage;
-
-        public IndexBuildResult(boolean success, LanceOptions.IndexType indexType, String columnName,
-                                String datasetPath, long durationMillis, String errorMessage) {
-            this.success = success;
-            this.indexType = indexType;
-            this.columnName = columnName;
-            this.datasetPath = datasetPath;
-            this.durationMillis = durationMillis;
-            this.errorMessage = errorMessage;
-        }
-
-        public boolean isSuccess() {
-            return success;
-        }
-
-        public LanceOptions.IndexType getIndexType() {
-            return indexType;
-        }
-
-        public String getColumnName() {
-            return columnName;
-        }
-
-        public String getDatasetPath() {
-            return datasetPath;
-        }
-
-        public long getDurationMillis() {
-            return durationMillis;
-        }
-
-        public String getErrorMessage() {
-            return errorMessage;
-        }
-
-        @Override
-        public String toString() {
-            return "IndexBuildResult{" +
-                    "success=" + success +
-                    ", indexType=" + indexType +
-                    ", columnName='" + columnName + '\'' +
-                    ", datasetPath='" + datasetPath + '\'' +
-                    ", durationMillis=" + durationMillis +
-                    ", errorMessage='" + errorMessage + '\'' +
-                    '}';
-        }
+  /** Builder */
+  public static class Builder {
+    private String datasetPath;
+    private String columnName;
+    private LanceOptions.IndexType indexType = LanceOptions.IndexType.IVF_PQ;
+    private LanceOptions.MetricType metricType = LanceOptions.MetricType.L2;
+    private int numPartitions = 256;
+    private Integer numSubVectors;
+    private int numBits = 8;
+    private int maxLevel = 7;
+    private int m = 16;
+    private int efConstruction = 100;
+    private boolean replace = false;
+
+    public Builder datasetPath(String datasetPath) {
+      this.datasetPath = datasetPath;
+      return this;
+    }
+
+    public Builder columnName(String columnName) {
+      this.columnName = columnName;
+      return this;
+    }
+
+    public Builder indexType(LanceOptions.IndexType indexType) {
+      this.indexType = indexType;
+      return this;
+    }
+
+    public Builder metricType(LanceOptions.MetricType metricType) {
+      this.metricType = metricType;
+      return this;
+    }
+
+    public Builder numPartitions(int numPartitions) {
+      this.numPartitions = numPartitions;
+      return this;
+    }
+
+    public Builder numSubVectors(Integer numSubVectors) {
+      this.numSubVectors = numSubVectors;
+      return this;
+    }
+
+    public Builder numBits(int numBits) {
+      this.numBits = numBits;
+      return this;
+    }
+
+    public Builder maxLevel(int maxLevel) {
+      this.maxLevel = maxLevel;
+      return this;
+    }
+
+    // checkstyle.off: MethodName
+    public Builder m(int m) {
+      this.m = m;
+      return this;
+    }
+
+    // checkstyle.on: MethodName
+
+    public Builder efConstruction(int efConstruction) {
+      this.efConstruction = efConstruction;
+      return this;
+    }
+
+    public Builder replace(boolean replace) {
+      this.replace = replace;
+      return this;
+    }
+
+    public LanceIndexBuilder build() {
+      validate();
+      return new LanceIndexBuilder(this);
+    }
+
+    private void validate() {
+      if (datasetPath == null || datasetPath.isEmpty()) {
+        throw new IllegalArgumentException("Dataset path cannot be empty");
+      }
+      if (columnName == null || columnName.isEmpty()) {
+        throw new IllegalArgumentException("Column name cannot be empty");
+      }
+      if (numPartitions <= 0) {
+        throw new IllegalArgumentException("Number of partitions must be greater than 0");
+      }
+      if (numSubVectors != null && numSubVectors <= 0) {
+        throw new IllegalArgumentException("Number of sub-vectors must be greater than 0");
+      }
+      if (numBits <= 0 || numBits > 16) {
+        throw new IllegalArgumentException("Quantization bits must be between 1 and 16");
+      }
+    }
+  }
+
+  /** Index build result */
+  public static class IndexBuildResult implements Serializable {
+    private static final long serialVersionUID = 1L;
+
+    private final boolean success;
+    private final LanceOptions.IndexType indexType;
+    private final String columnName;
+    private final String datasetPath;
+    private final long durationMillis;
+    private final String errorMessage;
+
+    public IndexBuildResult(
+        boolean success,
+        LanceOptions.IndexType indexType,
+        String columnName,
+        String datasetPath,
+        long durationMillis,
+        String errorMessage) {
+      this.success = success;
+      this.indexType = indexType;
+      this.columnName = columnName;
+      this.datasetPath = datasetPath;
+      this.durationMillis = durationMillis;
+      this.errorMessage = errorMessage;
+    }
+
+    public boolean isSuccess() {
+      return success;
+    }
+
+    public LanceOptions.IndexType getIndexType() {
+      return indexType;
+    }
+
+    public String getColumnName() {
+      return columnName;
+    }
+
+    public String getDatasetPath() {
+      return datasetPath;
+    }
+
+    public long getDurationMillis() {
+      return durationMillis;
+    }
+
+    public String getErrorMessage() {
+      return errorMessage;
+    }
+
+    @Override
+    public String toString() {
+      return "IndexBuildResult{"
+          + "success="
+          + success
+          + ", indexType="
+          + indexType
+          + ", columnName='"
+          + columnName
+          + '\''
+          + ", datasetPath='"
+          + datasetPath
+          + '\''
+          + ", durationMillis="
+          + durationMillis
+          + ", errorMessage='"
+          + errorMessage
+          + '\''
+          + '}';
     }
+  }
 }
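[Reviewer sketch, not part of the patch: buildIndex() above catches its own exceptions
and reports them through IndexBuildResult instead of throwing, so callers must check
isSuccess() themselves. All builder and result methods below appear in this file's diff;
the column name "embedding" is an illustrative value only.]

    // Sketch only. LanceIndexBuilder is Closeable, so try-with-resources releases
    // the dataset and allocator even when the build fails.
    try (LanceIndexBuilder builder =
        LanceIndexBuilder.builder()
            .datasetPath("/path/to/dataset")
            .columnName("embedding") // illustrative column name
            .indexType(LanceOptions.IndexType.IVF_PQ)
            .numPartitions(256)
            .numSubVectors(16)
            .build()) {
      LanceIndexBuilder.IndexBuildResult result = builder.buildIndex();
      if (!result.isSuccess()) {
        // Failures are reported here rather than thrown from buildIndex().
        throw new java.io.IOException("Index build failed: " + result.getErrorMessage());
      }
      System.out.printf(
          "Built %s index on %s in %d ms%n",
          result.getIndexType(), result.getColumnName(), result.getDurationMillis());
    }
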
diff --git a/src/main/java/org/apache/flink/connector/lance/LanceInputFormat.java b/src/main/java/org/apache/flink/connector/lance/LanceInputFormat.java
index 2785a00..be58ce3 100644
--- a/src/main/java/org/apache/flink/connector/lance/LanceInputFormat.java
+++ b/src/main/java/org/apache/flink/connector/lance/LanceInputFormat.java
@@ -1,11 +1,7 @@
 /*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
@@ -15,7 +11,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.flink.connector.lance;
 
 import org.apache.flink.api.common.io.RichInputFormat;
@@ -41,7 +36,6 @@ import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
-import java.nio.file.Paths;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Iterator;
@@ -49,287 +43,279 @@
 
 /**
  * Lance InputFormat implementation.
- *
- * <p>Reads data from Lance dataset using InputFormat interface, supports parallel reading with splits.
+ *
+ * <p>Reads data from Lance dataset using InputFormat interface, supports parallel reading with
+ * splits.
  */
 public class LanceInputFormat extends RichInputFormat<RowData, LanceSplit> {
-    private static final long serialVersionUID = 1L;
-    private static final Logger LOG = LoggerFactory.getLogger(LanceInputFormat.class);
-
-    private final LanceOptions options;
-    private final RowType rowType;
-    private final String[] selectedColumns;
-
-    private transient BufferAllocator allocator;
-    private transient Dataset dataset;
-    private transient RowDataConverter converter;
-    private transient LanceScanner currentScanner;
-    private transient ArrowReader currentReader;
-    private transient Iterator<RowData> currentBatchIterator;
-    private transient boolean reachedEnd;
-
-    /**
-     * Create LanceInputFormat
-     *
-     * @param options Lance configuration options
-     * @param rowType Flink RowType
-     */
-    public LanceInputFormat(LanceOptions options, RowType rowType) {
-        this.options = options;
-        this.rowType = rowType;
-
-        List<String> columns = options.getReadColumns();
-        this.selectedColumns = columns != null && !columns.isEmpty()
-                ? columns.toArray(new String[0])
-                : null;
+  private static final long serialVersionUID = 1L;
+  private static final Logger LOG = LoggerFactory.getLogger(LanceInputFormat.class);
+
+  private final LanceOptions options;
+  private final RowType rowType;
+  private final String[] selectedColumns;
+
+  private transient BufferAllocator allocator;
+  private transient Dataset dataset;
+  private transient RowDataConverter converter;
+  private transient LanceScanner currentScanner;
+  private transient ArrowReader currentReader;
+  private transient Iterator<RowData> currentBatchIterator;
+  private transient boolean reachedEnd;
+
+  /**
+   * Create LanceInputFormat
+   *
+   * @param options Lance configuration options
+   * @param rowType Flink RowType
+   */
+  public LanceInputFormat(LanceOptions options, RowType rowType) {
+    this.options = options;
+    this.rowType = rowType;
+
+    List<String> columns = options.getReadColumns();
+    this.selectedColumns =
+        columns != null && !columns.isEmpty() ? columns.toArray(new String[0]) : null;
+  }
-    @Override
-    public void configure(Configuration parameters) {
-        // Configuration already done in constructor
-    }
-
-    @Override
-    public BaseStatistics getStatistics(BaseStatistics cachedStatistics) throws IOException {
-        // Return basic statistics
-        return cachedStatistics;
-    }
-
-    @Override
-    public LanceSplit[] createInputSplits(int minNumSplits) throws IOException {
-        LOG.info("Creating input splits, minimum split count: {}", minNumSplits);
-
-        String datasetPath = options.getPath();
-        if (datasetPath == null || datasetPath.isEmpty()) {
-            throw new IOException("Dataset path cannot be empty");
-        }
-
-        BufferAllocator tempAllocator = new RootAllocator(Long.MAX_VALUE);
-        try {
-            Dataset tempDataset = Dataset.open(datasetPath, tempAllocator);
-            try {
-                List<Fragment> fragments = tempDataset.getFragments();
-                LanceSplit[] splits = new LanceSplit[fragments.size()];
-
-                for (int i = 0; i < fragments.size(); i++) {
-                    Fragment fragment = fragments.get(i);
-                    long rowCount = fragment.countRows();
-                    splits[i] = new LanceSplit(i, fragment.getId(), datasetPath, rowCount);
-                }
-
-                LOG.info("Created {} input splits", splits.length);
-                return splits;
-            } finally {
-                tempDataset.close();
-            }
-        } finally {
-            tempAllocator.close();
-        }
-    }
-
-    @Override
-    public InputSplitAssigner getInputSplitAssigner(LanceSplit[] inputSplits) {
-        return new LanceSplitAssigner(inputSplits);
-    }
-
-    @Override
-    public void open(LanceSplit split) throws IOException {
-        LOG.info("Opening split: {}", split);
-
-        this.allocator = new RootAllocator(Long.MAX_VALUE);
-        this.reachedEnd = false;
-
-        // Open dataset
-        String datasetPath = split.getDatasetPath();
-        try {
-            this.dataset = Dataset.open(datasetPath, allocator);
-        } catch (Exception e) {
-            throw new IOException("Cannot open dataset: " + datasetPath, e);
-        }
-
-        // Initialize converter
-        RowType actualRowType = this.rowType;
-        if (actualRowType == null) {
-            Schema arrowSchema = dataset.getSchema();
-            actualRowType = LanceTypeConverter.toFlinkRowType(arrowSchema);
-        }
-        this.converter = new RowDataConverter(actualRowType);
-
-        // Get specified Fragment
-        List<Fragment> fragments = dataset.getFragments();
-        Fragment targetFragment = null;
-        for (Fragment fragment : fragments) {
-            if (fragment.getId() == split.getFragmentId()) {
-                targetFragment = fragment;
-                break;
-            }
-        }
-
-        if (targetFragment == null) {
-            throw new IOException("Cannot find Fragment: " + split.getFragmentId());
-        }
-
-        // Build scan options
-        ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder();
-        scanOptionsBuilder.batchSize(options.getReadBatchSize());
-
-        if (selectedColumns != null && selectedColumns.length > 0) {
-            scanOptionsBuilder.columns(Arrays.asList(selectedColumns));
-        }
-
-        String filter = options.getReadFilter();
-        if (filter != null && !filter.isEmpty()) {
-            scanOptionsBuilder.filter(filter);
-        }
-
-        ScanOptions scanOptions = scanOptionsBuilder.build();
-
-        // Create Scanner
-        try {
-            this.currentScanner = targetFragment.newScan(scanOptions);
-            this.currentReader = currentScanner.scanBatches();
-        } catch (Exception e) {
-            throw new IOException("Failed to create Scanner", e);
-        }
-
-        // Load first batch of data
-        loadNextBatch();
-    }
-
-    /**
-     * Load next batch of data
-     */
-    private void loadNextBatch() throws IOException {
-        try {
-            if (currentReader.loadNextBatch()) {
-                VectorSchemaRoot root = currentReader.getVectorSchemaRoot();
-                List<RowData> rows = converter.toRowDataList(root);
-                this.currentBatchIterator = rows.iterator();
-            } else {
-                this.reachedEnd = true;
-                this.currentBatchIterator = null;
-            }
-        } catch (Exception e) {
-            throw new IOException("Failed to load data batch", e);
-        }
-    }
-
-    @Override
-    public boolean reachedEnd() throws IOException {
-        return reachedEnd;
-    }
-
-    @Override
-    public RowData nextRecord(RowData reuse) throws IOException {
-        if (reachedEnd) {
-            return null;
-        }
-
-        // Current batch still has data
-        if (currentBatchIterator != null && currentBatchIterator.hasNext()) {
-            return currentBatchIterator.next();
-        }
-
-        // Load next batch
-        loadNextBatch();
-
-        if (reachedEnd) {
-            return null;
-        }
-
-        if (currentBatchIterator != null && currentBatchIterator.hasNext()) {
-            return currentBatchIterator.next();
-        }
-
-        return null;
-    }
-
-    @Override
-    public void close() throws IOException {
-        LOG.info("Closing LanceInputFormat");
-
-        if (currentReader != null) {
-            try {
-                currentReader.close();
-            } catch (Exception e) {
-                LOG.warn("Failed to close Reader", e);
-            }
-            currentReader = null;
-        }
-
-        if (currentScanner != null) {
-            try {
-                currentScanner.close();
-            } catch (Exception e) {
-                LOG.warn("Failed to close Scanner", e);
-            }
-            currentScanner = null;
-        }
-
-        if (dataset != null) {
-            try {
-                dataset.close();
-            } catch (Exception e) {
-                LOG.warn("Failed to close dataset", e);
-            }
-            dataset = null;
-        }
-
-        if (allocator != null) {
-            try {
-                allocator.close();
-            } catch (Exception e) {
-                LOG.warn("Failed to close allocator", e);
-            }
-            allocator = null;
-        }
-    }
-
-    /**
-     * Get RowType
-     */
-    public RowType getRowType() {
-        return rowType;
-    }
-
-    /**
-     * Get configuration options
-     */
-    public LanceOptions getOptions() {
-        return options;
-    }
-
-    /**
-     * Lance split assigner
-     */
-    private static class LanceSplitAssigner implements InputSplitAssigner {
-        private final List<LanceSplit> remainingSplits;
-
-        public LanceSplitAssigner(LanceSplit[] splits) {
-            this.remainingSplits = new ArrayList<>();
-            for (LanceSplit split : splits) {
-                remainingSplits.add(split);
-            }
-        }
-
-        @Override
-        public synchronized LanceSplit getNextInputSplit(String host, int taskId) {
-            if (remainingSplits.isEmpty()) {
-                return null;
-            }
-            return remainingSplits.remove(remainingSplits.size() - 1);
-        }
-
-        @Override
-        public void returnInputSplit(List<org.apache.flink.core.io.InputSplit> splits, int taskId) {
-            for (org.apache.flink.core.io.InputSplit split : splits) {
-                if (split instanceof LanceSplit) {
-                    synchronized (this) {
-                        remainingSplits.add((LanceSplit) split);
-                    }
-                }
-            }
-        }
-    }
+  @Override
+  public void configure(Configuration parameters) {
+    // Configuration already done in constructor
+  }
+
+  @Override
+  public BaseStatistics getStatistics(BaseStatistics cachedStatistics) throws IOException {
+    // Return basic statistics
+    return cachedStatistics;
+  }
+
+  @Override
+  public LanceSplit[] createInputSplits(int minNumSplits) throws IOException {
+    LOG.info("Creating input splits, minimum split count: {}", minNumSplits);
+
+    String datasetPath = options.getPath();
+    if (datasetPath == null || datasetPath.isEmpty()) {
+      throw new IOException("Dataset path cannot be empty");
+    }
+
+    BufferAllocator tempAllocator = new RootAllocator(Long.MAX_VALUE);
+    try {
+      Dataset tempDataset = Dataset.open(datasetPath, tempAllocator);
+      try {
+        List<Fragment> fragments = tempDataset.getFragments();
+        LanceSplit[] splits = new LanceSplit[fragments.size()];
+
+        for (int i = 0; i < fragments.size(); i++) {
+          Fragment fragment = fragments.get(i);
+          long rowCount = fragment.countRows();
+          splits[i] = new LanceSplit(i, fragment.getId(), datasetPath, rowCount);
+        }
+
+        LOG.info("Created {} input splits", splits.length);
+        return splits;
+      } finally {
+        tempDataset.close();
+      }
+    } finally {
+      tempAllocator.close();
+    }
+  }
+
+  @Override
+  public InputSplitAssigner getInputSplitAssigner(LanceSplit[] inputSplits) {
+    return new LanceSplitAssigner(inputSplits);
+  }
+
+  @Override
+  public void open(LanceSplit split) throws IOException {
+    LOG.info("Opening split: {}", split);
+
+    this.allocator = new RootAllocator(Long.MAX_VALUE);
+    this.reachedEnd = false;
+
+    // Open dataset
+    String datasetPath = split.getDatasetPath();
+    try {
+      this.dataset = Dataset.open(datasetPath, allocator);
+    } catch (Exception e) {
+      throw new IOException("Cannot open dataset: " + datasetPath, e);
+    }
+
+    // Initialize converter
+    RowType actualRowType = this.rowType;
+    if (actualRowType == null) {
+      Schema arrowSchema = dataset.getSchema();
+      actualRowType = LanceTypeConverter.toFlinkRowType(arrowSchema);
+    }
+    this.converter = new RowDataConverter(actualRowType);
+
+    // Get specified Fragment
+    List<Fragment> fragments = dataset.getFragments();
+    Fragment targetFragment = null;
+    for (Fragment fragment : fragments) {
+      if (fragment.getId() == split.getFragmentId()) {
+        targetFragment = fragment;
+        break;
+      }
+    }
+
+    if (targetFragment == null) {
+      throw new IOException("Cannot find Fragment: " + split.getFragmentId());
+    }
+
+    // Build scan options
+    ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder();
+    scanOptionsBuilder.batchSize(options.getReadBatchSize());
+
+    if (selectedColumns != null && selectedColumns.length > 0) {
+      scanOptionsBuilder.columns(Arrays.asList(selectedColumns));
+    }
+
+    String filter = options.getReadFilter();
+    if (filter != null && !filter.isEmpty()) {
+      scanOptionsBuilder.filter(filter);
+    }
+
+    ScanOptions scanOptions = scanOptionsBuilder.build();
+
+    // Create Scanner
+    try {
+      this.currentScanner = targetFragment.newScan(scanOptions);
+      this.currentReader = currentScanner.scanBatches();
+    } catch (Exception e) {
+      throw new IOException("Failed to create Scanner", e);
+    }
+
+    // Load first batch of data
+    loadNextBatch();
+  }
+
+  /** Load next batch of data */
+  private void loadNextBatch() throws IOException {
+    try {
+      if (currentReader.loadNextBatch()) {
+        VectorSchemaRoot root = currentReader.getVectorSchemaRoot();
+        List<RowData> rows = converter.toRowDataList(root);
+        this.currentBatchIterator = rows.iterator();
+      } else {
+        this.reachedEnd = true;
+        this.currentBatchIterator = null;
+      }
+    } catch (Exception e) {
+      throw new IOException("Failed to load data batch", e);
+    }
+  }
+
+  @Override
+  public boolean reachedEnd() throws IOException {
+    return reachedEnd;
+  }
+
+  @Override
+  public RowData nextRecord(RowData reuse) throws IOException {
+    if (reachedEnd) {
+      return null;
+    }
+
+    // Current batch still has data
+    if (currentBatchIterator != null && currentBatchIterator.hasNext()) {
+      return currentBatchIterator.next();
+    }
+
+    // Load next batch
+    loadNextBatch();
+
+    if (reachedEnd) {
+      return null;
+    }
+
+    if (currentBatchIterator != null && currentBatchIterator.hasNext()) {
+      return currentBatchIterator.next();
+    }
+
+    return null;
+  }
+
+  @Override
+  public void close() throws IOException {
+    LOG.info("Closing LanceInputFormat");
+
+    if (currentReader != null) {
+      try {
+        currentReader.close();
+      } catch (Exception e) {
+        LOG.warn("Failed to close Reader", e);
+      }
+      currentReader = null;
+    }
+
+    if (currentScanner != null) {
+      try {
+        currentScanner.close();
+      } catch (Exception e) {
+        LOG.warn("Failed to close Scanner", e);
+      }
+      currentScanner = null;
+    }
+
+    if (dataset != null) {
+      try {
+        dataset.close();
+      } catch (Exception e) {
+        LOG.warn("Failed to close dataset", e);
+      }
+      dataset = null;
+    }
+
+    if (allocator != null) {
+      try {
+        allocator.close();
+      } catch (Exception e) {
+        LOG.warn("Failed to close allocator", e);
+      }
+      allocator = null;
+    }
+  }
+
+  /** Get RowType */
+  public RowType getRowType() {
+    return rowType;
+  }
+
+  /** Get configuration options */
+  public LanceOptions getOptions() {
+    return options;
+  }
+
+  /** Lance split assigner */
+  private static class LanceSplitAssigner implements InputSplitAssigner {
+    private final List<LanceSplit> remainingSplits;
+
+    LanceSplitAssigner(LanceSplit[] splits) {
+      this.remainingSplits = new ArrayList<>();
+      for (LanceSplit split : splits) {
+        remainingSplits.add(split);
+      }
+    }
+
+    @Override
+    public synchronized LanceSplit getNextInputSplit(String host, int taskId) {
+      if (remainingSplits.isEmpty()) {
+        return null;
+      }
+      return remainingSplits.remove(remainingSplits.size() - 1);
+    }
+
+    @Override
+    public void returnInputSplit(List<org.apache.flink.core.io.InputSplit> splits, int taskId) {
+      for (org.apache.flink.core.io.InputSplit split : splits) {
+        if (split instanceof LanceSplit) {
+          synchronized (this) {
+            remainingSplits.add((LanceSplit) split);
+          }
+        }
+      }
+    }
+  }
 }

<p>Writes Flink RowData to Lance dataset, supports batch writing and Checkpoint.
- *
+ *
 * <p>Usage example:
+ *
 * <pre>{@code
  * LanceOptions options = LanceOptions.builder()
  *     .path("/path/to/lance/dataset")
  *     .writeBatchSize(1024)
  *     .writeMode(WriteMode.APPEND)
  *     .build();
- * 
+ *
  * LanceSink sink = new LanceSink(options, rowType);
  * dataStream.addSink(sink);
  * }</pre>
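  *
  * <p>A minimal sketch of the checkpoint interplay (the 30s interval below is illustrative,
  * not a recommendation): rows buffer until {@code writeBatchSize} is reached, and
  * {@code snapshotState} also calls {@code flush()}, so enabling checkpointing bounds how
  * long a row can wait in the buffer:
  *
  * <pre>{@code
  * StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
  * env.enableCheckpointing(30_000L); // each checkpoint triggers snapshotState() -> flush()
  * dataStream.addSink(sink);
  * env.execute("lance-sink-job");
  * }</pre>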
*/ public class LanceSink extends RichSinkFunction implements CheckpointedFunction { - private static final long serialVersionUID = 1L; - private static final Logger LOG = LoggerFactory.getLogger(LanceSink.class); - - private final LanceOptions options; - private final RowType rowType; - - private transient BufferAllocator allocator; - private transient Dataset dataset; - private transient RowDataConverter converter; - private transient Schema arrowSchema; - private transient List buffer; - private transient long totalWrittenRows; - private transient boolean datasetExists; - private transient boolean isFirstWrite; - - /** - * Create LanceSink - * - * @param options Lance configuration options - * @param rowType Flink RowType - */ - public LanceSink(LanceOptions options, RowType rowType) { - this.options = options; - this.rowType = rowType; + private static final long serialVersionUID = 1L; + private static final Logger LOG = LoggerFactory.getLogger(LanceSink.class); + + private final LanceOptions options; + private final RowType rowType; + + private transient BufferAllocator allocator; + private transient Dataset dataset; + private transient RowDataConverter converter; + private transient Schema arrowSchema; + private transient List buffer; + private transient long totalWrittenRows; + private transient boolean datasetExists; + private transient boolean isFirstWrite; + + /** + * Create LanceSink + * + * @param options Lance configuration options + * @param rowType Flink RowType + */ + public LanceSink(LanceOptions options, RowType rowType) { + this.options = options; + this.rowType = rowType; + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + + LOG.info("Opening Lance Sink: {}", options.getPath()); + + this.allocator = new RootAllocator(Long.MAX_VALUE); + this.buffer = new ArrayList<>(options.getWriteBatchSize()); + this.totalWrittenRows = 0; + this.isFirstWrite = true; + + // Initialize converter and Schema + this.converter = new RowDataConverter(rowType); + this.arrowSchema = LanceTypeConverter.toArrowSchema(rowType); + + // Check if dataset exists + String datasetPath = options.getPath(); + if (datasetPath == null || datasetPath.isEmpty()) { + throw new IllegalArgumentException("Lance dataset path cannot be empty"); } - @Override - public void open(Configuration parameters) throws Exception { - super.open(parameters); - - LOG.info("Opening Lance Sink: {}", options.getPath()); - - this.allocator = new RootAllocator(Long.MAX_VALUE); - this.buffer = new ArrayList<>(options.getWriteBatchSize()); - this.totalWrittenRows = 0; - this.isFirstWrite = true; - - // Initialize converter and Schema - this.converter = new RowDataConverter(rowType); - this.arrowSchema = LanceTypeConverter.toArrowSchema(rowType); - - // Check if dataset exists - String datasetPath = options.getPath(); - if (datasetPath == null || datasetPath.isEmpty()) { - throw new IllegalArgumentException("Lance dataset path cannot be empty"); - } - - Path path = Paths.get(datasetPath); - this.datasetExists = Files.exists(path); - - // If overwrite mode and dataset exists, delete first - if (datasetExists && options.getWriteMode() == LanceOptions.WriteMode.OVERWRITE) { - LOG.info("Overwrite mode, deleting existing dataset: {}", datasetPath); - deleteDirectory(path); - this.datasetExists = false; - } - - LOG.info("Lance Sink opened, Schema: {}", rowType); + Path path = Paths.get(datasetPath); + this.datasetExists = Files.exists(path); + + // If overwrite mode and dataset 
exists, delete first + if (datasetExists && options.getWriteMode() == LanceOptions.WriteMode.OVERWRITE) { + LOG.info("Overwrite mode, deleting existing dataset: {}", datasetPath); + deleteDirectory(path); + this.datasetExists = false; } - @Override - public void invoke(RowData value, Context context) throws Exception { - buffer.add(value); - - // When buffer reaches batch size, execute write - if (buffer.size() >= options.getWriteBatchSize()) { - flush(); - } + LOG.info("Lance Sink opened, Schema: {}", rowType); + } + + @Override + public void invoke(RowData value, Context context) throws Exception { + buffer.add(value); + + // When buffer reaches batch size, execute write + if (buffer.size() >= options.getWriteBatchSize()) { + flush(); } + } - /** - * Flush buffer, write data to Lance dataset - */ - public void flush() throws IOException { - if (buffer.isEmpty()) { - return; - } - - LOG.debug("Flushing buffer, row count: {}", buffer.size()); - - try (VectorSchemaRoot root = VectorSchemaRoot.create(arrowSchema, allocator)) { - // Convert RowData to VectorSchemaRoot - converter.toVectorSchemaRoot(buffer, root); - - String datasetPath = options.getPath(); - - // Build write parameters - WriteParams writeParams = new WriteParams.Builder() - .withMaxRowsPerFile(options.getWriteMaxRowsPerFile()) - .build(); - - // Create Fragment - List fragments = Fragment.create( - datasetPath, - allocator, - root, - writeParams - ); - - if (!datasetExists) { - // Create new dataset (using Overwrite operation) - FragmentOperation.Overwrite overwrite = new FragmentOperation.Overwrite(fragments, arrowSchema); - dataset = overwrite.commit(allocator, datasetPath, Optional.empty(), Collections.emptyMap()); - datasetExists = true; - isFirstWrite = false; - LOG.info("Created new dataset: {}", datasetPath); - } else { - // Append data - if (isFirstWrite && options.getWriteMode() == LanceOptions.WriteMode.OVERWRITE) { - // First write and overwrite mode - FragmentOperation.Overwrite overwrite = new FragmentOperation.Overwrite(fragments, arrowSchema); - dataset = overwrite.commit(allocator, datasetPath, Optional.empty(), Collections.emptyMap()); - isFirstWrite = false; - } else { - // Append mode - FragmentOperation.Append append = new FragmentOperation.Append(fragments); - dataset = append.commit(allocator, datasetPath, Optional.empty(), Collections.emptyMap()); - } - } - - totalWrittenRows += buffer.size(); - LOG.debug("Written {} rows, total: {} rows", buffer.size(), totalWrittenRows); - - buffer.clear(); - } catch (Exception e) { - throw new IOException("Failed to write Lance dataset", e); - } + /** Flush buffer, write data to Lance dataset */ + public void flush() throws IOException { + if (buffer.isEmpty()) { + return; } - @Override - public void close() throws Exception { - LOG.info("Closing Lance Sink"); - // Flush remaining data - try { - flush(); - } catch (Exception e) { - LOG.warn("Failed to flush data on close", e); - } - if (dataset != null) { - try { - dataset.close(); - } catch (Exception e) { - LOG.warn("Failed to close dataset", e); - } - dataset = null; - } - - if (allocator != null) { - try { - allocator.close(); - } catch (Exception e) { - LOG.warn("Failed to close allocator", e); - } - allocator = null; + LOG.debug("Flushing buffer, row count: {}", buffer.size()); + + try (VectorSchemaRoot root = VectorSchemaRoot.create(arrowSchema, allocator)) { + // Convert RowData to VectorSchemaRoot + converter.toVectorSchemaRoot(buffer, root); + + String datasetPath = options.getPath(); + + // Build write 
parameters + WriteParams writeParams = + new WriteParams.Builder().withMaxRowsPerFile(options.getWriteMaxRowsPerFile()).build(); + + // Create Fragment + List fragments = Fragment.create(datasetPath, allocator, root, writeParams); + + if (!datasetExists) { + // Create new dataset (using Overwrite operation) + FragmentOperation.Overwrite overwrite = + new FragmentOperation.Overwrite(fragments, arrowSchema); + dataset = + overwrite.commit(allocator, datasetPath, Optional.empty(), Collections.emptyMap()); + datasetExists = true; + isFirstWrite = false; + LOG.info("Created new dataset: {}", datasetPath); + } else { + // Append data + if (isFirstWrite && options.getWriteMode() == LanceOptions.WriteMode.OVERWRITE) { + // First write and overwrite mode + FragmentOperation.Overwrite overwrite = + new FragmentOperation.Overwrite(fragments, arrowSchema); + dataset = + overwrite.commit(allocator, datasetPath, Optional.empty(), Collections.emptyMap()); + isFirstWrite = false; + } else { + // Append mode + FragmentOperation.Append append = new FragmentOperation.Append(fragments); + dataset = append.commit(allocator, datasetPath, Optional.empty(), Collections.emptyMap()); } - - LOG.info("Lance Sink closed, total written {} rows", totalWrittenRows); - - super.close(); - } + } - @Override - public void snapshotState(FunctionSnapshotContext context) throws Exception { - LOG.debug("Snapshot state, checkpointId: {}", context.getCheckpointId()); - - // Flush all buffered data at Checkpoint - flush(); - } + totalWrittenRows += buffer.size(); + LOG.debug("Written {} rows, total: {} rows", buffer.size(), totalWrittenRows); - @Override - public void initializeState(FunctionInitializationContext context) throws Exception { - LOG.debug("Initialize state, isRestored: {}", context.isRestored()); - // State initialization (if recovery needed) + buffer.clear(); + } catch (Exception e) { + throw new IOException("Failed to write Lance dataset", e); } + } - /** - * Get RowType - */ - public RowType getRowType() { - return rowType; + @Override + public void close() throws Exception { + LOG.info("Closing Lance Sink"); + // Flush remaining data + try { + flush(); + } catch (Exception e) { + LOG.warn("Failed to flush data on close", e); } - - /** - * Get configuration options - */ - public LanceOptions getOptions() { - return options; + if (dataset != null) { + try { + dataset.close(); + } catch (Exception e) { + LOG.warn("Failed to close dataset", e); + } + dataset = null; } - /** - * Get total written row count - */ - public long getTotalWrittenRows() { - return totalWrittenRows; + if (allocator != null) { + try { + allocator.close(); + } catch (Exception e) { + LOG.warn("Failed to close allocator", e); + } + allocator = null; } - /** - * Recursively delete directory - */ - private void deleteDirectory(Path path) throws IOException { - if (Files.isDirectory(path)) { - Files.list(path).forEach(child -> { + LOG.info("Lance Sink closed, total written {} rows", totalWrittenRows); + + super.close(); + } + + @Override + public void snapshotState(FunctionSnapshotContext context) throws Exception { + LOG.debug("Snapshot state, checkpointId: {}", context.getCheckpointId()); + + // Flush all buffered data at Checkpoint + flush(); + } + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + LOG.debug("Initialize state, isRestored: {}", context.isRestored()); + // State initialization (if recovery needed) + } + + /** Get RowType */ + public RowType getRowType() { + return rowType; + } + + 
/** Get configuration options */ + public LanceOptions getOptions() { + return options; + } + + /** Get total written row count */ + public long getTotalWrittenRows() { + return totalWrittenRows; + } + + /** Recursively delete directory */ + private void deleteDirectory(Path path) throws IOException { + if (Files.isDirectory(path)) { + Files.list(path) + .forEach( + child -> { try { - deleteDirectory(child); + deleteDirectory(child); } catch (IOException e) { - LOG.warn("Failed to delete file: {}", child, e); + LOG.warn("Failed to delete file: {}", child, e); } - }); - } - Files.deleteIfExists(path); + }); } + Files.deleteIfExists(path); + } + + /** Builder pattern constructor */ + public static Builder builder() { + return new Builder(); + } - /** - * Builder pattern constructor - */ - public static Builder builder() { - return new Builder(); + /** LanceSink Builder */ + public static class Builder { + private String path; + private int batchSize = 1024; + private LanceOptions.WriteMode writeMode = LanceOptions.WriteMode.APPEND; + private int maxRowsPerFile = 1000000; + private RowType rowType; + + public Builder path(String path) { + this.path = path; + return this; } - /** - * LanceSink Builder - */ - public static class Builder { - private String path; - private int batchSize = 1024; - private LanceOptions.WriteMode writeMode = LanceOptions.WriteMode.APPEND; - private int maxRowsPerFile = 1000000; - private RowType rowType; - - public Builder path(String path) { - this.path = path; - return this; - } + public Builder batchSize(int batchSize) { + this.batchSize = batchSize; + return this; + } - public Builder batchSize(int batchSize) { - this.batchSize = batchSize; - return this; - } + public Builder writeMode(LanceOptions.WriteMode writeMode) { + this.writeMode = writeMode; + return this; + } - public Builder writeMode(LanceOptions.WriteMode writeMode) { - this.writeMode = writeMode; - return this; - } + public Builder maxRowsPerFile(int maxRowsPerFile) { + this.maxRowsPerFile = maxRowsPerFile; + return this; + } - public Builder maxRowsPerFile(int maxRowsPerFile) { - this.maxRowsPerFile = maxRowsPerFile; - return this; - } + public Builder rowType(RowType rowType) { + this.rowType = rowType; + return this; + } - public Builder rowType(RowType rowType) { - this.rowType = rowType; - return this; - } + public LanceSink build() { + if (path == null || path.isEmpty()) { + throw new IllegalArgumentException("Dataset path cannot be empty"); + } - public LanceSink build() { - if (path == null || path.isEmpty()) { - throw new IllegalArgumentException("Dataset path cannot be empty"); - } - - if (rowType == null) { - throw new IllegalArgumentException("RowType cannot be null"); - } - - LanceOptions options = LanceOptions.builder() - .path(path) - .writeBatchSize(batchSize) - .writeMode(writeMode) - .writeMaxRowsPerFile(maxRowsPerFile) - .build(); - - return new LanceSink(options, rowType); - } + if (rowType == null) { + throw new IllegalArgumentException("RowType cannot be null"); + } + + LanceOptions options = + LanceOptions.builder() + .path(path) + .writeBatchSize(batchSize) + .writeMode(writeMode) + .writeMaxRowsPerFile(maxRowsPerFile) + .build(); + + return new LanceSink(options, rowType); } + } } diff --git a/src/main/java/org/apache/flink/connector/lance/LanceSource.java b/src/main/java/org/apache/flink/connector/lance/LanceSource.java index ade00a0..6b15699 100644 --- a/src/main/java/org/apache/flink/connector/lance/LanceSource.java +++ 
b/src/main/java/org/apache/flink/connector/lance/LanceSource.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,10 +11,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.flink.connector.lance; -import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.configuration.Configuration; import org.apache.flink.connector.lance.config.LanceOptions; import org.apache.flink.connector.lance.converter.LanceTypeConverter; @@ -47,364 +41,354 @@ /** * Lance data source implementation. - * + * *

<p>Reads data from Lance dataset and converts to Flink RowData.
+ *
 * <p>Supports column pruning, predicate push-down and Limit push-down optimization.
- *
+ *
 * <p>Usage example:
+ *
 * <pre>{@code
  * LanceOptions options = LanceOptions.builder()
  *     .path("/path/to/lance/dataset")
  *     .readBatchSize(1024)
  *     .readLimit(100L)  // Limit push-down
  *     .build();
- * 
+ *
  * LanceSource source = new LanceSource(options, rowType);
  * DataStream<RowData> stream = env.addSource(source);
  * }</pre>
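  *
  * <p>A sketch combining the push-down options via the builder (the column names and filter
  * value are hypothetical; note that when a filter or limit is pushed down, only subtask 0
  * reads, to avoid duplicate rows):
  *
  * <pre>{@code
  * LanceSource prunedSource = LanceSource.builder()
  *     .path("/path/to/lance/dataset")
  *     .columns(Arrays.asList("id", "name"))  // column pruning
  *     .filter("id > 100")                    // predicate push-down
  *     .limit(1000L)                          // Limit push-down
  *     .build();
  * }</pre>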
*/ public class LanceSource extends RichParallelSourceFunction { - private static final long serialVersionUID = 1L; - private static final Logger LOG = LoggerFactory.getLogger(LanceSource.class); - - private final LanceOptions options; - private final RowType rowType; - private final String[] selectedColumns; - private final Long readLimit; // Added: Limit push-down - - private transient volatile boolean running; - private transient BufferAllocator allocator; - private transient Dataset dataset; - private transient RowDataConverter converter; - private transient long emittedCount; // Added: emitted row count - - /** - * Create LanceSource - * - * @param options Lance configuration options - * @param rowType Flink RowType - */ - public LanceSource(LanceOptions options, RowType rowType) { - this.options = options; - this.rowType = rowType; - - List columns = options.getReadColumns(); - this.selectedColumns = columns != null && !columns.isEmpty() - ? columns.toArray(new String[0]) - : null; - this.readLimit = options.getReadLimit(); + private static final long serialVersionUID = 1L; + private static final Logger LOG = LoggerFactory.getLogger(LanceSource.class); + + private final LanceOptions options; + private final RowType rowType; + private final String[] selectedColumns; + private final Long readLimit; // Added: Limit push-down + + private transient volatile boolean running; + private transient BufferAllocator allocator; + private transient Dataset dataset; + private transient RowDataConverter converter; + private transient long emittedCount; // Added: emitted row count + + /** + * Create LanceSource + * + * @param options Lance configuration options + * @param rowType Flink RowType + */ + public LanceSource(LanceOptions options, RowType rowType) { + this.options = options; + this.rowType = rowType; + + List columns = options.getReadColumns(); + this.selectedColumns = + columns != null && !columns.isEmpty() ? 
columns.toArray(new String[0]) : null; + this.readLimit = options.getReadLimit(); + } + + /** + * Create LanceSource (auto-infer Schema) + * + * @param options Lance configuration options + */ + public LanceSource(LanceOptions options) { + this(options, null); + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + + LOG.info("Opening Lance data source: {}", options.getPath()); + if (readLimit != null) { + LOG.info("Limit push-down enabled, max read rows: {}", readLimit); } - /** - * Create LanceSource (auto-infer Schema) - * - * @param options Lance configuration options - */ - public LanceSource(LanceOptions options) { - this(options, null); + this.running = true; + this.emittedCount = 0; + this.allocator = new RootAllocator(Long.MAX_VALUE); + + // Open Lance dataset + String datasetPath = options.getPath(); + if (datasetPath == null || datasetPath.isEmpty()) { + throw new IllegalArgumentException("Lance dataset path cannot be empty"); } - @Override - public void open(Configuration parameters) throws Exception { - super.open(parameters); - - LOG.info("Opening Lance data source: {}", options.getPath()); - if (readLimit != null) { - LOG.info("Limit push-down enabled, max read rows: {}", readLimit); - } - - this.running = true; - this.emittedCount = 0; - this.allocator = new RootAllocator(Long.MAX_VALUE); - - // Open Lance dataset - String datasetPath = options.getPath(); - if (datasetPath == null || datasetPath.isEmpty()) { - throw new IllegalArgumentException("Lance dataset path cannot be empty"); - } - - Path path = Paths.get(datasetPath); - try { - this.dataset = Dataset.open(path.toString(), allocator); - } catch (Exception e) { - throw new IOException("Cannot open Lance dataset: " + datasetPath, e); - } - - // Initialize RowDataConverter - RowType actualRowType = this.rowType; - if (actualRowType == null) { - // Infer RowType from dataset Schema - Schema arrowSchema = dataset.getSchema(); - actualRowType = LanceTypeConverter.toFlinkRowType(arrowSchema); - } - this.converter = new RowDataConverter(actualRowType); - - LOG.info("Lance data source opened, Schema: {}", actualRowType); + Path path = Paths.get(datasetPath); + try { + this.dataset = Dataset.open(path.toString(), allocator); + } catch (Exception e) { + throw new IOException("Cannot open Lance dataset: " + datasetPath, e); } - @Override - public void run(SourceContext ctx) throws Exception { - LOG.info("Start reading Lance dataset: {}", options.getPath()); - - int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask(); - int numSubtasks = getRuntimeContext().getNumberOfParallelSubtasks(); - - String filter = options.getReadFilter(); - - // If filter condition exists, use Dataset level scan (only execute on first subtask to avoid duplicate data) - if (filter != null && !filter.isEmpty()) { - if (subtaskIndex == 0) { - LOG.info("Using Dataset level scan (with filter condition)"); - readDatasetWithFilter(ctx); - } else { - LOG.info("Subtask {} skipped (only subtask 0 executes in filter mode)", subtaskIndex); - } - } else if (readLimit != null) { - // With Limit, only execute on first subtask to avoid duplicate data - if (subtaskIndex == 0) { - LOG.info("Using Dataset level scan (with Limit)"); - readDatasetWithFilter(ctx); - } else { - LOG.info("Subtask {} skipped (only subtask 0 executes in Limit mode)", subtaskIndex); - } - } else { - // Without filter condition and Limit, use Fragment level parallel scan - List fragments = dataset.getFragments(); - LOG.info("Dataset has 
{} Fragments, current subtask {}/{}", - fragments.size(), subtaskIndex, numSubtasks); - - // Assign Fragments by subtask - for (int i = 0; i < fragments.size() && running && !isLimitReached(); i++) { - // Simple round-robin assignment strategy - if (i % numSubtasks != subtaskIndex) { - continue; - } - - Fragment fragment = fragments.get(i); - readFragment(ctx, fragment); - } - } - - LOG.info("Lance data source read completed, total emitted {} rows", emittedCount); + // Initialize RowDataConverter + RowType actualRowType = this.rowType; + if (actualRowType == null) { + // Infer RowType from dataset Schema + Schema arrowSchema = dataset.getSchema(); + actualRowType = LanceTypeConverter.toFlinkRowType(arrowSchema); } + this.converter = new RowDataConverter(actualRowType); - /** - * Use Dataset level scan (supports filter conditions and Limit) - */ - private void readDatasetWithFilter(SourceContext ctx) throws Exception { - // Build scan options - ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder(); - - // Set batch size - scanOptionsBuilder.batchSize(options.getReadBatchSize()); - - // Set column filter - if (selectedColumns != null && selectedColumns.length > 0) { - scanOptionsBuilder.columns(Arrays.asList(selectedColumns)); - } - - // Set data filter condition - String filter = options.getReadFilter(); - if (filter != null && !filter.isEmpty()) { - LOG.info("Applying filter condition: {}", filter); - scanOptionsBuilder.filter(filter); - } - - ScanOptions scanOptions = scanOptionsBuilder.build(); - - // Use Dataset level scan - try (LanceScanner scanner = dataset.newScan(scanOptions)) { - try (ArrowReader reader = scanner.scanBatches()) { - while (reader.loadNextBatch() && running && !isLimitReached()) { - VectorSchemaRoot root = reader.getVectorSchemaRoot(); - - // Convert to RowData and output - List rows = converter.toRowDataList(root); - synchronized (ctx.getCheckpointLock()) { - for (RowData row : rows) { - if (isLimitReached()) { - break; - } - ctx.collect(row); - emittedCount++; - } - } - } - } - } - - if (isLimitReached()) { - LOG.info("Reached Limit ({}), stop reading", readLimit); + LOG.info("Lance data source opened, Schema: {}", actualRowType); + } + + @Override + public void run(SourceContext ctx) throws Exception { + LOG.info("Start reading Lance dataset: {}", options.getPath()); + + int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask(); + int numSubtasks = getRuntimeContext().getNumberOfParallelSubtasks(); + + String filter = options.getReadFilter(); + + // If filter condition exists, use Dataset level scan (only execute on first subtask to avoid + // duplicate data) + if (filter != null && !filter.isEmpty()) { + if (subtaskIndex == 0) { + LOG.info("Using Dataset level scan (with filter condition)"); + readDatasetWithFilter(ctx); + } else { + LOG.info("Subtask {} skipped (only subtask 0 executes in filter mode)", subtaskIndex); + } + } else if (readLimit != null) { + // With Limit, only execute on first subtask to avoid duplicate data + if (subtaskIndex == 0) { + LOG.info("Using Dataset level scan (with Limit)"); + readDatasetWithFilter(ctx); + } else { + LOG.info("Subtask {} skipped (only subtask 0 executes in Limit mode)", subtaskIndex); + } + } else { + // Without filter condition and Limit, use Fragment level parallel scan + List fragments = dataset.getFragments(); + LOG.info( + "Dataset has {} Fragments, current subtask {}/{}", + fragments.size(), + subtaskIndex, + numSubtasks); + + // Assign Fragments by subtask + for (int i = 0; i < 
fragments.size() && running && !isLimitReached(); i++) { + // Simple round-robin assignment strategy + if (i % numSubtasks != subtaskIndex) { + continue; } + + Fragment fragment = fragments.get(i); + readFragment(ctx, fragment); + } } - /** - * Read single Fragment (without filter condition, but supports Limit) - */ - private void readFragment(SourceContext ctx, Fragment fragment) throws Exception { - LOG.debug("Reading Fragment: {}", fragment.getId()); - - // Build scan options - ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder(); - - // Set batch size - scanOptionsBuilder.batchSize(options.getReadBatchSize()); - - // Set column filter - if (selectedColumns != null && selectedColumns.length > 0) { - scanOptionsBuilder.columns(Arrays.asList(selectedColumns)); - } - - // Note: Fragment level scan does not use filter, filter is only supported at Dataset level - - ScanOptions scanOptions = scanOptionsBuilder.build(); - - // Create Scanner and read data - try (LanceScanner scanner = fragment.newScan(scanOptions)) { - try (ArrowReader reader = scanner.scanBatches()) { - while (reader.loadNextBatch() && running && !isLimitReached()) { - VectorSchemaRoot root = reader.getVectorSchemaRoot(); - - // Convert to RowData and output - List rows = converter.toRowDataList(root); - synchronized (ctx.getCheckpointLock()) { - for (RowData row : rows) { - if (isLimitReached()) { - break; - } - ctx.collect(row); - emittedCount++; - } - } - } + LOG.info("Lance data source read completed, total emitted {} rows", emittedCount); + } + + /** Use Dataset level scan (supports filter conditions and Limit) */ + private void readDatasetWithFilter(SourceContext ctx) throws Exception { + // Build scan options + ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder(); + + // Set batch size + scanOptionsBuilder.batchSize(options.getReadBatchSize()); + + // Set column filter + if (selectedColumns != null && selectedColumns.length > 0) { + scanOptionsBuilder.columns(Arrays.asList(selectedColumns)); + } + + // Set data filter condition + String filter = options.getReadFilter(); + if (filter != null && !filter.isEmpty()) { + LOG.info("Applying filter condition: {}", filter); + scanOptionsBuilder.filter(filter); + } + + ScanOptions scanOptions = scanOptionsBuilder.build(); + + // Use Dataset level scan + try (LanceScanner scanner = dataset.newScan(scanOptions)) { + try (ArrowReader reader = scanner.scanBatches()) { + while (reader.loadNextBatch() && running && !isLimitReached()) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + + // Convert to RowData and output + List rows = converter.toRowDataList(root); + synchronized (ctx.getCheckpointLock()) { + for (RowData row : rows) { + if (isLimitReached()) { + break; + } + ctx.collect(row); + emittedCount++; } + } } + } } - /** - * Check if Limit has been reached - */ - private boolean isLimitReached() { - return readLimit != null && emittedCount >= readLimit; + if (isLimitReached()) { + LOG.info("Reached Limit ({}), stop reading", readLimit); } + } - @Override - public void cancel() { - LOG.info("Cancel Lance data source"); - this.running = false; + /** Read single Fragment (without filter condition, but supports Limit) */ + private void readFragment(SourceContext ctx, Fragment fragment) throws Exception { + LOG.debug("Reading Fragment: {}", fragment.getId()); + + // Build scan options + ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder(); + + // Set batch size + scanOptionsBuilder.batchSize(options.getReadBatchSize()); + + // Set 
column filter + if (selectedColumns != null && selectedColumns.length > 0) { + scanOptionsBuilder.columns(Arrays.asList(selectedColumns)); } - @Override - public void close() throws Exception { - LOG.info("Closing Lance data source"); - - this.running = false; - - if (dataset != null) { - try { - dataset.close(); - } catch (Exception e) { - LOG.warn("Error closing Lance dataset", e); - } - dataset = null; - } - - if (allocator != null) { - try { - allocator.close(); - } catch (Exception e) { - LOG.warn("Error closing memory allocator", e); + // Note: Fragment level scan does not use filter, filter is only supported at Dataset level + + ScanOptions scanOptions = scanOptionsBuilder.build(); + + // Create Scanner and read data + try (LanceScanner scanner = fragment.newScan(scanOptions)) { + try (ArrowReader reader = scanner.scanBatches()) { + while (reader.loadNextBatch() && running && !isLimitReached()) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + + // Convert to RowData and output + List rows = converter.toRowDataList(root); + synchronized (ctx.getCheckpointLock()) { + for (RowData row : rows) { + if (isLimitReached()) { + break; + } + ctx.collect(row); + emittedCount++; } - allocator = null; + } } - - super.close(); + } } + } - /** - * Get RowType - */ - public RowType getRowType() { - return rowType; - } + /** Check if Limit has been reached */ + private boolean isLimitReached() { + return readLimit != null && emittedCount >= readLimit; + } - /** - * Get configuration options - */ - public LanceOptions getOptions() { - return options; - } + @Override + public void cancel() { + LOG.info("Cancel Lance data source"); + this.running = false; + } - /** - * Get selected columns - */ - public String[] getSelectedColumns() { - return selectedColumns; + @Override + public void close() throws Exception { + LOG.info("Closing Lance data source"); + + this.running = false; + + if (dataset != null) { + try { + dataset.close(); + } catch (Exception e) { + LOG.warn("Error closing Lance dataset", e); + } + dataset = null; } - /** - * Builder pattern constructor - */ - public static Builder builder() { - return new Builder(); + if (allocator != null) { + try { + allocator.close(); + } catch (Exception e) { + LOG.warn("Error closing memory allocator", e); + } + allocator = null; } - /** - * LanceSource Builder - */ - public static class Builder { - private String path; - private int batchSize = 1024; - private List columns; - private String filter; - private Long limit; // Added - private RowType rowType; - - public Builder path(String path) { - this.path = path; - return this; - } + super.close(); + } - public Builder batchSize(int batchSize) { - this.batchSize = batchSize; - return this; - } + /** Get RowType */ + public RowType getRowType() { + return rowType; + } - public Builder columns(List columns) { - this.columns = columns; - return this; - } + /** Get configuration options */ + public LanceOptions getOptions() { + return options; + } - public Builder filter(String filter) { - this.filter = filter; - return this; - } + /** Get selected columns */ + public String[] getSelectedColumns() { + return selectedColumns; + } - public Builder limit(Long limit) { - this.limit = limit; - return this; - } + /** Builder pattern constructor */ + public static Builder builder() { + return new Builder(); + } - public Builder rowType(RowType rowType) { - this.rowType = rowType; - return this; - } + /** LanceSource Builder */ + public static class Builder { + private String path; + private int 
batchSize = 1024; + private List columns; + private String filter; + private Long limit; // Added + private RowType rowType; - public LanceSource build() { - if (path == null || path.isEmpty()) { - throw new IllegalArgumentException("Dataset path cannot be empty"); - } + public Builder path(String path) { + this.path = path; + return this; + } - LanceOptions options = LanceOptions.builder() - .path(path) - .readBatchSize(batchSize) - .readColumns(columns) - .readFilter(filter) - .readLimit(limit) - .build(); + public Builder batchSize(int batchSize) { + this.batchSize = batchSize; + return this; + } - return new LanceSource(options, rowType); - } + public Builder columns(List columns) { + this.columns = columns; + return this; + } + + public Builder filter(String filter) { + this.filter = filter; + return this; + } + + public Builder limit(Long limit) { + this.limit = limit; + return this; + } + + public Builder rowType(RowType rowType) { + this.rowType = rowType; + return this; + } + + public LanceSource build() { + if (path == null || path.isEmpty()) { + throw new IllegalArgumentException("Dataset path cannot be empty"); + } + + LanceOptions options = + LanceOptions.builder() + .path(path) + .readBatchSize(batchSize) + .readColumns(columns) + .readFilter(filter) + .readLimit(limit) + .build(); + + return new LanceSource(options, rowType); } + } } diff --git a/src/main/java/org/apache/flink/connector/lance/LanceSplit.java b/src/main/java/org/apache/flink/connector/lance/LanceSplit.java index 1aec727..2aa1d2e 100644 --- a/src/main/java/org/apache/flink/connector/lance/LanceSplit.java +++ b/src/main/java/org/apache/flink/connector/lance/LanceSplit.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.flink.connector.lance; import org.apache.flink.core.io.InputSplit; @@ -25,97 +20,88 @@ /** * Lance data split. - * + * *

Represents a Fragment in Lance dataset, used for parallel data reading. */ public class LanceSplit implements InputSplit, Serializable { - private static final long serialVersionUID = 1L; - - /** - * Split number - */ - private final int splitNumber; - - /** - * Fragment ID - */ - private final int fragmentId; - - /** - * Dataset path - */ - private final String datasetPath; - - /** - * Row count in Fragment (estimated) - */ - private final long rowCount; - - /** - * Create LanceSplit - * - * @param splitNumber Split number - * @param fragmentId Fragment ID - * @param datasetPath Dataset path - * @param rowCount Row count - */ - public LanceSplit(int splitNumber, int fragmentId, String datasetPath, long rowCount) { - this.splitNumber = splitNumber; - this.fragmentId = fragmentId; - this.datasetPath = datasetPath; - this.rowCount = rowCount; - } - - @Override - public int getSplitNumber() { - return splitNumber; - } - - /** - * Get Fragment ID - */ - public int getFragmentId() { - return fragmentId; - } - - /** - * Get dataset path - */ - public String getDatasetPath() { - return datasetPath; - } - - /** - * Get row count - */ - public long getRowCount() { - return rowCount; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - LanceSplit that = (LanceSplit) o; - return splitNumber == that.splitNumber && - fragmentId == that.fragmentId && - rowCount == that.rowCount && - Objects.equals(datasetPath, that.datasetPath); - } - - @Override - public int hashCode() { - return Objects.hash(splitNumber, fragmentId, datasetPath, rowCount); - } - - @Override - public String toString() { - return "LanceSplit{" + - "splitNumber=" + splitNumber + - ", fragmentId=" + fragmentId + - ", datasetPath='" + datasetPath + '\'' + - ", rowCount=" + rowCount + - '}'; - } + private static final long serialVersionUID = 1L; + + /** Split number */ + private final int splitNumber; + + /** Fragment ID */ + private final int fragmentId; + + /** Dataset path */ + private final String datasetPath; + + /** Row count in Fragment (estimated) */ + private final long rowCount; + + /** + * Create LanceSplit + * + * @param splitNumber Split number + * @param fragmentId Fragment ID + * @param datasetPath Dataset path + * @param rowCount Row count + */ + public LanceSplit(int splitNumber, int fragmentId, String datasetPath, long rowCount) { + this.splitNumber = splitNumber; + this.fragmentId = fragmentId; + this.datasetPath = datasetPath; + this.rowCount = rowCount; + } + + @Override + public int getSplitNumber() { + return splitNumber; + } + + /** Get Fragment ID */ + public int getFragmentId() { + return fragmentId; + } + + /** Get dataset path */ + public String getDatasetPath() { + return datasetPath; + } + + /** Get row count */ + public long getRowCount() { + return rowCount; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + LanceSplit that = (LanceSplit) o; + return splitNumber == that.splitNumber + && fragmentId == that.fragmentId + && rowCount == that.rowCount + && Objects.equals(datasetPath, that.datasetPath); + } + + @Override + public int hashCode() { + return Objects.hash(splitNumber, fragmentId, datasetPath, rowCount); + } + + @Override + public String toString() { + return "LanceSplit{" + + "splitNumber=" + + splitNumber + + ", fragmentId=" + + fragmentId + + ", datasetPath='" + + datasetPath + + '\'' + + ", rowCount=" + + 
rowCount + + '}'; + } } diff --git a/src/main/java/org/apache/flink/connector/lance/LanceVectorSearch.java b/src/main/java/org/apache/flink/connector/lance/LanceVectorSearch.java index faf0c5a..25018ae 100644 --- a/src/main/java/org/apache/flink/connector/lance/LanceVectorSearch.java +++ b/src/main/java/org/apache/flink/connector/lance/LanceVectorSearch.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.flink.connector.lance; import org.apache.flink.connector.lance.config.LanceOptions; @@ -49,10 +44,11 @@ /** * Lance vector search implementation. - * + * *

<p>Supports KNN search with L2, Cosine, and Dot distance metrics.
- *
+ *
 * <p>Usage example:
+ *
 * <pre>{@code
  * LanceVectorSearch search = LanceVectorSearch.builder()
  *     .datasetPath("/path/to/dataset")
@@ -60,391 +56,368 @@
  *     .metricType(MetricType.L2)
  *     .nprobes(20)
  *     .build();
- * 
+ *
  * List<SearchResult> results = search.search(queryVector, 10);
  * }</pre>
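  *
  * <p>A sketch of a filtered search (the filter string is hypothetical; it uses SQL WHERE
  * syntax and is applied alongside the KNN query):
  *
  * <pre>{@code
  * List<SearchResult> filtered = search.search(queryVector, 10, "category = 'news'");
  * for (SearchResult r : filtered) {
  *     double similarity = r.getSimilarity(); // 1 / (1 + distance) for L2
  * }
  * }</pre>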
*/ public class LanceVectorSearch implements Closeable, Serializable { - private static final long serialVersionUID = 1L; - private static final Logger LOG = LoggerFactory.getLogger(LanceVectorSearch.class); - - private final String datasetPath; - private final String columnName; - private final MetricType metricType; - private final int nprobes; - private final int ef; - private final Integer refineFactor; - - private transient BufferAllocator allocator; - private transient Dataset dataset; - private transient RowType rowType; - private transient RowDataConverter converter; - - private LanceVectorSearch(Builder builder) { - this.datasetPath = builder.datasetPath; - this.columnName = builder.columnName; - this.metricType = builder.metricType; - this.nprobes = builder.nprobes; - this.ef = builder.ef; - this.refineFactor = builder.refineFactor; + private static final long serialVersionUID = 1L; + private static final Logger LOG = LoggerFactory.getLogger(LanceVectorSearch.class); + + private final String datasetPath; + private final String columnName; + private final MetricType metricType; + private final int nprobes; + private final int ef; + private final Integer refineFactor; + + private transient BufferAllocator allocator; + private transient Dataset dataset; + private transient RowType rowType; + private transient RowDataConverter converter; + + private LanceVectorSearch(Builder builder) { + this.datasetPath = builder.datasetPath; + this.columnName = builder.columnName; + this.metricType = builder.metricType; + this.nprobes = builder.nprobes; + this.ef = builder.ef; + this.refineFactor = builder.refineFactor; + } + + /** Open dataset connection */ + public void open() throws IOException { + LOG.info("Opening vector search, dataset: {}", datasetPath); + + this.allocator = new RootAllocator(Long.MAX_VALUE); + + try { + this.dataset = Dataset.open(datasetPath, allocator); + + // Get Schema and create converter + Schema arrowSchema = dataset.getSchema(); + this.rowType = LanceTypeConverter.toFlinkRowType(arrowSchema); + this.converter = new RowDataConverter(rowType); + + } catch (Exception e) { + throw new IOException("Cannot open dataset: " + datasetPath, e); } - - /** - * Open dataset connection - */ - public void open() throws IOException { - LOG.info("Opening vector search, dataset: {}", datasetPath); - - this.allocator = new RootAllocator(Long.MAX_VALUE); - - try { - this.dataset = Dataset.open(datasetPath, allocator); - - // Get Schema and create converter - Schema arrowSchema = dataset.getSchema(); - this.rowType = LanceTypeConverter.toFlinkRowType(arrowSchema); - this.converter = new RowDataConverter(rowType); - - } catch (Exception e) { - throw new IOException("Cannot open dataset: " + datasetPath, e); - } + } + + /** + * Execute vector search + * + * @param queryVector Query vector + * @param k Number of nearest neighbors to return + * @return List of search results + */ + public List search(float[] queryVector, int k) throws IOException { + return search(queryVector, k, null); + } + + /** + * Execute vector search (with filter condition) + * + * @param queryVector Query vector + * @param k Number of nearest neighbors to return + * @param filter Filter condition (SQL WHERE syntax) + * @return List of search results + */ + public List search(float[] queryVector, int k, String filter) throws IOException { + if (dataset == null) { + open(); } - /** - * Execute vector search - * - * @param queryVector Query vector - * @param k Number of nearest neighbors to return - * @return List of 
search results - */ - public List search(float[] queryVector, int k) throws IOException { - return search(queryVector, k, null); - } + LOG.debug("Executing vector search, k={}, vector dimension={}", k, queryVector.length); - /** - * Execute vector search (with filter condition) - * - * @param queryVector Query vector - * @param k Number of nearest neighbors to return - * @param filter Filter condition (SQL WHERE syntax) - * @return List of search results - */ - public List search(float[] queryVector, int k, String filter) throws IOException { - if (dataset == null) { - open(); - } - - LOG.debug("Executing vector search, k={}, vector dimension={}", k, queryVector.length); - - // Validate query vector - validateQueryVector(queryVector); - - List results = new ArrayList<>(); - - try { - // Build vector query - Query.Builder queryBuilder = new Query.Builder() - .setColumn(columnName) - .setKey(queryVector) - .setK(k) - .setNprobes(nprobes) - .setDistanceType(toDistanceType(metricType)) - .setUseIndex(true); - - if (ef > 0) { - queryBuilder.setEf(ef); - } - - if (refineFactor != null && refineFactor > 0) { - queryBuilder.setRefineFactor(refineFactor); - } - - Query query = queryBuilder.build(); - - // Build scan options - ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder() - .nearest(query) - .withRowId(true); - - if (filter != null && !filter.isEmpty()) { - scanOptionsBuilder.filter(filter); - } - - ScanOptions scanOptions = scanOptionsBuilder.build(); - - // Execute search - try (LanceScanner scanner = dataset.newScan(scanOptions)) { - try (ArrowReader reader = scanner.scanBatches()) { - while (reader.loadNextBatch()) { - VectorSchemaRoot root = reader.getVectorSchemaRoot(); - - // Convert to RowData - List rows = converter.toRowDataList(root); - - // Try to get distance score (if _distance column exists) - Float8Vector distanceVector = null; - try { - distanceVector = (Float8Vector) root.getVector("_distance"); - } catch (Exception e) { - // _distance column may not exist - } - - for (int i = 0; i < rows.size(); i++) { - double distance = 0.0; - if (distanceVector != null && !distanceVector.isNull(i)) { - distance = distanceVector.get(i); - } - results.add(new SearchResult(rows.get(i), distance)); - } - } - } - } - - LOG.debug("Search completed, returned {} results", results.size()); - return results; - - } catch (Exception e) { - throw new IOException("Vector search failed", e); - } - } + // Validate query vector + validateQueryVector(queryVector); - /** - * Execute vector search (return RowData list) - * - * @param queryVector Query vector - * @param k Number of nearest neighbors to return - * @return RowData list - */ - public List searchRowData(float[] queryVector, int k) throws IOException { - List results = search(queryVector, k); - List rowDataList = new ArrayList<>(results.size()); - - for (SearchResult result : results) { - // Append distance score to RowData - GenericRowData rowWithDistance = new GenericRowData(rowType.getFieldCount() + 1); - RowData originalRow = result.getRowData(); - - for (int i = 0; i < rowType.getFieldCount(); i++) { - rowWithDistance.setField(i, getFieldValue(originalRow, i)); - } - rowWithDistance.setField(rowType.getFieldCount(), result.getDistance()); - - rowDataList.add(rowWithDistance); - } - - return rowDataList; - } + List results = new ArrayList<>(); - /** - * Get field value from RowData - */ - private Object getFieldValue(RowData rowData, int index) { - if (rowData.isNullAt(index)) { - return null; - } - - // Simplified handling, 
should get based on field type in practice - if (rowData instanceof GenericRowData) { - return ((GenericRowData) rowData).getField(index); - } - - return null; - } + try { + // Build vector query + Query.Builder queryBuilder = + new Query.Builder() + .setColumn(columnName) + .setKey(queryVector) + .setK(k) + .setNprobes(nprobes) + .setDistanceType(toDistanceType(metricType)) + .setUseIndex(true); - /** - * Validate query vector - */ - private void validateQueryVector(float[] queryVector) throws IOException { - if (queryVector == null || queryVector.length == 0) { - throw new IllegalArgumentException("Query vector cannot be empty"); - } - - // Check for NaN or Infinity values - for (float value : queryVector) { - if (Float.isNaN(value) || Float.isInfinite(value)) { - throw new IllegalArgumentException("Query vector contains invalid values (NaN or Infinity)"); - } - } - } + if (ef > 0) { + queryBuilder.setEf(ef); + } - /** - * Convert distance metric type - */ - private DistanceType toDistanceType(MetricType metricType) { - switch (metricType) { - case L2: - return DistanceType.L2; - case COSINE: - return DistanceType.Cosine; - case DOT: - return DistanceType.Dot; - default: - return DistanceType.L2; - } - } + if (refineFactor != null && refineFactor > 0) { + queryBuilder.setRefineFactor(refineFactor); + } - @Override - public void close() throws IOException { - if (dataset != null) { + Query query = queryBuilder.build(); + + // Build scan options + ScanOptions.Builder scanOptionsBuilder = + new ScanOptions.Builder().nearest(query).withRowId(true); + + if (filter != null && !filter.isEmpty()) { + scanOptionsBuilder.filter(filter); + } + + ScanOptions scanOptions = scanOptionsBuilder.build(); + + // Execute search + try (LanceScanner scanner = dataset.newScan(scanOptions)) { + try (ArrowReader reader = scanner.scanBatches()) { + while (reader.loadNextBatch()) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + + // Convert to RowData + List rows = converter.toRowDataList(root); + + // Try to get distance score (if _distance column exists) + Float8Vector distanceVector = null; try { - dataset.close(); + distanceVector = (Float8Vector) root.getVector("_distance"); } catch (Exception e) { - LOG.warn("Failed to close dataset", e); + // _distance column may not exist } - dataset = null; - } - - if (allocator != null) { - try { - allocator.close(); - } catch (Exception e) { - LOG.warn("Failed to close allocator", e); + + for (int i = 0; i < rows.size(); i++) { + double distance = 0.0; + if (distanceVector != null && !distanceVector.isNull(i)) { + distance = distanceVector.get(i); + } + results.add(new SearchResult(rows.get(i), distance)); } - allocator = null; + } } + } + + LOG.debug("Search completed, returned {} results", results.size()); + return results; + + } catch (Exception e) { + throw new IOException("Vector search failed", e); } + } + + /** + * Execute vector search (return RowData list) + * + * @param queryVector Query vector + * @param k Number of nearest neighbors to return + * @return RowData list + */ + public List searchRowData(float[] queryVector, int k) throws IOException { + List results = search(queryVector, k); + List rowDataList = new ArrayList<>(results.size()); + + for (SearchResult result : results) { + // Append distance score to RowData + GenericRowData rowWithDistance = new GenericRowData(rowType.getFieldCount() + 1); + RowData originalRow = result.getRowData(); + + for (int i = 0; i < rowType.getFieldCount(); i++) { + rowWithDistance.setField(i, 
getFieldValue(originalRow, i)); + } + rowWithDistance.setField(rowType.getFieldCount(), result.getDistance()); + + rowDataList.add(rowWithDistance); + } + + return rowDataList; + } - /** - * Get RowType - */ - public RowType getRowType() { - return rowType; + /** Get field value from RowData */ + private Object getFieldValue(RowData rowData, int index) { + if (rowData.isNullAt(index)) { + return null; } - /** - * Create builder - */ - public static Builder builder() { - return new Builder(); + // Simplified handling, should get based on field type in practice + if (rowData instanceof GenericRowData) { + return ((GenericRowData) rowData).getField(index); } - /** - * Create vector searcher from LanceOptions - */ - public static LanceVectorSearch fromOptions(LanceOptions options) { - return builder() - .datasetPath(options.getPath()) - .columnName(options.getVectorColumn()) - .metricType(options.getVectorMetric()) - .nprobes(options.getVectorNprobes()) - .ef(options.getVectorEf()) - .refineFactor(options.getVectorRefineFactor()) - .build(); + return null; + } + + /** Validate query vector */ + private void validateQueryVector(float[] queryVector) throws IOException { + if (queryVector == null || queryVector.length == 0) { + throw new IllegalArgumentException("Query vector cannot be empty"); } - /** - * Builder - */ - public static class Builder { - private String datasetPath; - private String columnName; - private MetricType metricType = MetricType.L2; - private int nprobes = 20; - private int ef = 100; - private Integer refineFactor; - - public Builder datasetPath(String datasetPath) { - this.datasetPath = datasetPath; - return this; - } + // Check for NaN or Infinity values + for (float value : queryVector) { + if (Float.isNaN(value) || Float.isInfinite(value)) { + throw new IllegalArgumentException( + "Query vector contains invalid values (NaN or Infinity)"); + } + } + } + + /** Convert distance metric type */ + private DistanceType toDistanceType(MetricType metricType) { + switch (metricType) { + case L2: + return DistanceType.L2; + case COSINE: + return DistanceType.Cosine; + case DOT: + return DistanceType.Dot; + default: + return DistanceType.L2; + } + } + + @Override + public void close() throws IOException { + if (dataset != null) { + try { + dataset.close(); + } catch (Exception e) { + LOG.warn("Failed to close dataset", e); + } + dataset = null; + } - public Builder columnName(String columnName) { - this.columnName = columnName; - return this; - } + if (allocator != null) { + try { + allocator.close(); + } catch (Exception e) { + LOG.warn("Failed to close allocator", e); + } + allocator = null; + } + } + + /** Get RowType */ + public RowType getRowType() { + return rowType; + } + + /** Create builder */ + public static Builder builder() { + return new Builder(); + } + + /** Create vector searcher from LanceOptions */ + public static LanceVectorSearch fromOptions(LanceOptions options) { + return builder() + .datasetPath(options.getPath()) + .columnName(options.getVectorColumn()) + .metricType(options.getVectorMetric()) + .nprobes(options.getVectorNprobes()) + .ef(options.getVectorEf()) + .refineFactor(options.getVectorRefineFactor()) + .build(); + } + + /** Builder */ + public static class Builder { + private String datasetPath; + private String columnName; + private MetricType metricType = MetricType.L2; + private int nprobes = 20; + private int ef = 100; + private Integer refineFactor; + + public Builder datasetPath(String datasetPath) { + this.datasetPath = datasetPath; + return 
this; + } - public Builder metricType(MetricType metricType) { - this.metricType = metricType; - return this; - } + public Builder columnName(String columnName) { + this.columnName = columnName; + return this; + } - public Builder nprobes(int nprobes) { - this.nprobes = nprobes; - return this; - } + public Builder metricType(MetricType metricType) { + this.metricType = metricType; + return this; + } - public Builder ef(int ef) { - this.ef = ef; - return this; - } + public Builder nprobes(int nprobes) { + this.nprobes = nprobes; + return this; + } - public Builder refineFactor(Integer refineFactor) { - this.refineFactor = refineFactor; - return this; - } + public Builder ef(int ef) { + this.ef = ef; + return this; + } - public LanceVectorSearch build() { - validate(); - return new LanceVectorSearch(this); - } + public Builder refineFactor(Integer refineFactor) { + this.refineFactor = refineFactor; + return this; + } - private void validate() { - if (datasetPath == null || datasetPath.isEmpty()) { - throw new IllegalArgumentException("Dataset path cannot be empty"); - } - if (columnName == null || columnName.isEmpty()) { - throw new IllegalArgumentException("Column name cannot be empty"); - } - if (nprobes <= 0) { - throw new IllegalArgumentException("nprobes must be greater than 0"); - } - } + public LanceVectorSearch build() { + validate(); + return new LanceVectorSearch(this); } - /** - * Search result - */ - public static class SearchResult implements Serializable { - private static final long serialVersionUID = 1L; + private void validate() { + if (datasetPath == null || datasetPath.isEmpty()) { + throw new IllegalArgumentException("Dataset path cannot be empty"); + } + if (columnName == null || columnName.isEmpty()) { + throw new IllegalArgumentException("Column name cannot be empty"); + } + if (nprobes <= 0) { + throw new IllegalArgumentException("nprobes must be greater than 0"); + } + } + } - private final RowData rowData; - private final double distance; + /** Search result */ + public static class SearchResult implements Serializable { + private static final long serialVersionUID = 1L; - public SearchResult(RowData rowData, double distance) { - this.rowData = rowData; - this.distance = distance; - } + private final RowData rowData; + private final double distance; - public RowData getRowData() { - return rowData; - } + public SearchResult(RowData rowData, double distance) { + this.rowData = rowData; + this.distance = distance; + } - public double getDistance() { - return distance; - } + public RowData getRowData() { + return rowData; + } - /** - * Get similarity score (inverse or negative of distance, depending on distance type) - */ - public double getSimilarity() { - if (distance == 0) { - return 1.0; - } - // For L2 distance, use 1 / (1 + distance) as similarity - return 1.0 / (1.0 + distance); - } + public double getDistance() { + return distance; + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - SearchResult that = (SearchResult) o; - return Double.compare(that.distance, distance) == 0 && - Objects.equals(rowData, that.rowData); - } + /** Get similarity score (inverse or negative of distance, depending on distance type) */ + public double getSimilarity() { + if (distance == 0) { + return 1.0; + } + // For L2 distance, use 1 / (1 + distance) as similarity + return 1.0 / (1.0 + distance); + } - @Override - public int hashCode() { - return Objects.hash(rowData, distance); - } + 
@Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SearchResult that = (SearchResult) o; + return Double.compare(that.distance, distance) == 0 && Objects.equals(rowData, that.rowData); + } - @Override - public String toString() { - return "SearchResult{" + - "rowData=" + rowData + - ", distance=" + distance + - '}'; - } + @Override + public int hashCode() { + return Objects.hash(rowData, distance); + } + + @Override + public String toString() { + return "SearchResult{" + "rowData=" + rowData + ", distance=" + distance + '}'; } + } } diff --git a/src/main/java/org/apache/flink/connector/lance/aggregate/AggregateExecutor.java b/src/main/java/org/apache/flink/connector/lance/aggregate/AggregateExecutor.java index 5509e14..d51bc5c 100644 --- a/src/main/java/org/apache/flink/connector/lance/aggregate/AggregateExecutor.java +++ b/src/main/java/org/apache/flink/connector/lance/aggregate/AggregateExecutor.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.flink.connector.lance.aggregate; import org.apache.flink.table.data.DecimalData; @@ -41,518 +36,493 @@ /** * Aggregate executor. - * - *

Executes aggregate calculations at data source side, supports COUNT, SUM, AVG, MIN, MAX and other aggregate functions. + * + *
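+ * <p>For illustration only (not introduced by this commit), a minimal usage sketch of this
+ * executor, assuming a source schema with a VARCHAR "category" and an INT "amount" column and
+ * an Iterable<RowData> named "rows" that matches that schema:
+ *
+ * <pre>{@code
+ * RowType sourceRowType =
+ *     RowType.of(
+ *         new LogicalType[] {new VarCharType(), new IntType()},
+ *         new String[] {"category", "amount"});
+ * AggregateInfo info =
+ *     AggregateInfo.builder().addCountStar("cnt").groupBy("category").build();
+ * AggregateExecutor executor = new AggregateExecutor(info, sourceRowType);
+ * executor.init();
+ * for (RowData row : rows) {
+ *   executor.accumulate(row); // accumulates COUNT(*) per "category" group
+ * }
+ * List<RowData> results = executor.getResults(); // one row per group: (category, cnt)
+ * }</pre>
+ *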

Executes aggregate calculations at data source side, supports COUNT, SUM, AVG, MIN, MAX and + * other aggregate functions. */ public class AggregateExecutor implements Serializable { - private static final long serialVersionUID = 1L; - private static final Logger LOG = LoggerFactory.getLogger(AggregateExecutor.class); + private static final long serialVersionUID = 1L; + private static final Logger LOG = LoggerFactory.getLogger(AggregateExecutor.class); - private final AggregateInfo aggregateInfo; - private final RowType sourceRowType; + private final AggregateInfo aggregateInfo; + private final RowType sourceRowType; - // Aggregate state (by group key) - private transient Map aggregateStates; - private transient boolean initialized; + // Aggregate state (by group key) + private transient Map aggregateStates; + private transient boolean initialized; - public AggregateExecutor(AggregateInfo aggregateInfo, RowType sourceRowType) { - this.aggregateInfo = aggregateInfo; - this.sourceRowType = sourceRowType; - } + public AggregateExecutor(AggregateInfo aggregateInfo, RowType sourceRowType) { + this.aggregateInfo = aggregateInfo; + this.sourceRowType = sourceRowType; + } + + /** Initialize aggregate executor */ + public void init() { + this.aggregateStates = new HashMap<>(); + this.initialized = true; + LOG.info("Initialized aggregate executor: {}", aggregateInfo); + } - /** - * Initialize aggregate executor - */ - public void init() { - this.aggregateStates = new HashMap<>(); - this.initialized = true; - LOG.info("Initialized aggregate executor: {}", aggregateInfo); + /** Accumulate a row to aggregate state */ + public void accumulate(RowData row) { + if (!initialized) { + init(); } - /** - * Accumulate a row to aggregate state - */ - public void accumulate(RowData row) { - if (!initialized) { - init(); + // Extract group key + GroupKey groupKey = extractGroupKey(row); + + // Get or create aggregate state + AggregateState state = + aggregateStates.computeIfAbsent( + groupKey, k -> new AggregateState(aggregateInfo.getAggregateCalls().size())); + + // Update state for each aggregate function + List calls = aggregateInfo.getAggregateCalls(); + for (int i = 0; i < calls.size(); i++) { + AggregateInfo.AggregateCall call = calls.get(i); + accumulateCall(state, i, call, row); + } + } + + /** Accumulate single aggregate function */ + private void accumulateCall( + AggregateState state, int index, AggregateInfo.AggregateCall call, RowData row) { + switch (call.getFunction()) { + case COUNT: + if (call.isCountStar()) { + // COUNT(*) + state.incrementCount(index); + } else { + // COUNT(column) - only count non-NULL values + int fieldIndex = getFieldIndex(call.getColumn()); + if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) { + state.incrementCount(index); + } + } + break; + + case COUNT_DISTINCT: + if (call.getColumn() != null) { + int fieldIndex = getFieldIndex(call.getColumn()); + if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) { + Object value = extractValue(row, fieldIndex); + state.addDistinctValue(index, value); + } + } + break; + + case SUM: + if (call.getColumn() != null) { + int fieldIndex = getFieldIndex(call.getColumn()); + if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) { + Number value = extractNumericValue(row, fieldIndex); + if (value != null) { + state.addSum(index, value.doubleValue()); + } + } } + break; - // Extract group key - GroupKey groupKey = extractGroupKey(row); - - // Get or create aggregate state - AggregateState state = aggregateStates.computeIfAbsent(groupKey, - k -> new 
AggregateState(aggregateInfo.getAggregateCalls().size())); - - // Update state for each aggregate function - List calls = aggregateInfo.getAggregateCalls(); - for (int i = 0; i < calls.size(); i++) { - AggregateInfo.AggregateCall call = calls.get(i); - accumulateCall(state, i, call, row); + case AVG: + if (call.getColumn() != null) { + int fieldIndex = getFieldIndex(call.getColumn()); + if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) { + Number value = extractNumericValue(row, fieldIndex); + if (value != null) { + state.addForAvg(index, value.doubleValue()); + } + } } - } + break; - /** - * Accumulate single aggregate function - */ - private void accumulateCall(AggregateState state, int index, - AggregateInfo.AggregateCall call, RowData row) { - switch (call.getFunction()) { - case COUNT: - if (call.isCountStar()) { - // COUNT(*) - state.incrementCount(index); - } else { - // COUNT(column) - only count non-NULL values - int fieldIndex = getFieldIndex(call.getColumn()); - if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) { - state.incrementCount(index); - } - } - break; - - case COUNT_DISTINCT: - if (call.getColumn() != null) { - int fieldIndex = getFieldIndex(call.getColumn()); - if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) { - Object value = extractValue(row, fieldIndex); - state.addDistinctValue(index, value); - } - } - break; - - case SUM: - if (call.getColumn() != null) { - int fieldIndex = getFieldIndex(call.getColumn()); - if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) { - Number value = extractNumericValue(row, fieldIndex); - if (value != null) { - state.addSum(index, value.doubleValue()); - } - } - } - break; - - case AVG: - if (call.getColumn() != null) { - int fieldIndex = getFieldIndex(call.getColumn()); - if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) { - Number value = extractNumericValue(row, fieldIndex); - if (value != null) { - state.addForAvg(index, value.doubleValue()); - } - } - } - break; - - case MIN: - if (call.getColumn() != null) { - int fieldIndex = getFieldIndex(call.getColumn()); - if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) { - Comparable value = extractComparableValue(row, fieldIndex); - if (value != null) { - state.updateMin(index, value); - } - } - } - break; - - case MAX: - if (call.getColumn() != null) { - int fieldIndex = getFieldIndex(call.getColumn()); - if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) { - Comparable value = extractComparableValue(row, fieldIndex); - if (value != null) { - state.updateMax(index, value); - } - } - } - break; + case MIN: + if (call.getColumn() != null) { + int fieldIndex = getFieldIndex(call.getColumn()); + if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) { + Comparable value = extractComparableValue(row, fieldIndex); + if (value != null) { + state.updateMin(index, value); + } + } } - } + break; - /** - * Get aggregate results - */ - public List getResults() { - if (!initialized || aggregateStates.isEmpty()) { - // If no data, return default aggregate result - return getDefaultResults(); + case MAX: + if (call.getColumn() != null) { + int fieldIndex = getFieldIndex(call.getColumn()); + if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) { + Comparable value = extractComparableValue(row, fieldIndex); + if (value != null) { + state.updateMax(index, value); + } + } } + break; + } + } - List results = new ArrayList<>(); - List calls = aggregateInfo.getAggregateCalls(); - List groupByCols = aggregateInfo.getGroupByColumns(); + /** Get aggregate results */ + public List getResults() { + if (!initialized || 
aggregateStates.isEmpty()) { + // If no data, return default aggregate result + return getDefaultResults(); + } - for (Map.Entry entry : aggregateStates.entrySet()) { - GroupKey groupKey = entry.getKey(); - AggregateState state = entry.getValue(); + List results = new ArrayList<>(); + List calls = aggregateInfo.getAggregateCalls(); + List groupByCols = aggregateInfo.getGroupByColumns(); - // Create result row: group columns + aggregate columns - int totalFields = groupByCols.size() + calls.size(); - GenericRowData resultRow = new GenericRowData(totalFields); + for (Map.Entry entry : aggregateStates.entrySet()) { + GroupKey groupKey = entry.getKey(); + AggregateState state = entry.getValue(); - // Fill group columns - for (int i = 0; i < groupByCols.size(); i++) { - resultRow.setField(i, groupKey.getValues()[i]); - } + // Create result row: group columns + aggregate columns + int totalFields = groupByCols.size() + calls.size(); + GenericRowData resultRow = new GenericRowData(totalFields); - // Fill aggregate results - for (int i = 0; i < calls.size(); i++) { - AggregateInfo.AggregateCall call = calls.get(i); - Object aggResult = getAggregateResult(state, i, call); - resultRow.setField(groupByCols.size() + i, aggResult); - } + // Fill group columns + for (int i = 0; i < groupByCols.size(); i++) { + resultRow.setField(i, groupKey.getValues()[i]); + } - results.add(resultRow); - } + // Fill aggregate results + for (int i = 0; i < calls.size(); i++) { + AggregateInfo.AggregateCall call = calls.get(i); + Object aggResult = getAggregateResult(state, i, call); + resultRow.setField(groupByCols.size() + i, aggResult); + } - LOG.info("Aggregate execution completed, generated {} result rows", results.size()); - return results; + results.add(resultRow); } - /** - * Get default aggregate result when no data - */ - private List getDefaultResults() { - // If has GROUP BY, no data means no result - if (aggregateInfo.hasGroupBy()) { - return new ArrayList<>(); - } - - // No GROUP BY, return default aggregate values - List calls = aggregateInfo.getAggregateCalls(); - GenericRowData resultRow = new GenericRowData(calls.size()); - - for (int i = 0; i < calls.size(); i++) { - AggregateInfo.AggregateCall call = calls.get(i); - switch (call.getFunction()) { - case COUNT: - case COUNT_DISTINCT: - resultRow.setField(i, 0L); - break; - default: - resultRow.setField(i, null); - break; - } - } + LOG.info("Aggregate execution completed, generated {} result rows", results.size()); + return results; + } - List results = new ArrayList<>(); - results.add(resultRow); - return results; - } - - /** - * Get single aggregate function result - */ - private Object getAggregateResult(AggregateState state, int index, - AggregateInfo.AggregateCall call) { - switch (call.getFunction()) { - case COUNT: - return state.getCount(index); - case COUNT_DISTINCT: - return (long) state.getDistinctCount(index); - case SUM: - Double sum = state.getSum(index); - return sum != null ? sum : null; - case AVG: - Double avg = state.getAvg(index); - return avg != null ? 
avg : null; - case MIN: - return state.getMin(index); - case MAX: - return state.getMax(index); - default: - return null; - } + /** Get default aggregate result when no data */ + private List getDefaultResults() { + // If has GROUP BY, no data means no result + if (aggregateInfo.hasGroupBy()) { + return new ArrayList<>(); } - /** - * Extract group key - */ - private GroupKey extractGroupKey(RowData row) { - List groupByCols = aggregateInfo.getGroupByColumns(); - if (groupByCols.isEmpty()) { - return GroupKey.EMPTY; - } - - Object[] keyValues = new Object[groupByCols.size()]; - for (int i = 0; i < groupByCols.size(); i++) { - int fieldIndex = getFieldIndex(groupByCols.get(i)); - if (fieldIndex >= 0) { - keyValues[i] = extractValue(row, fieldIndex); - } - } - return new GroupKey(keyValues); + // No GROUP BY, return default aggregate values + List calls = aggregateInfo.getAggregateCalls(); + GenericRowData resultRow = new GenericRowData(calls.size()); + + for (int i = 0; i < calls.size(); i++) { + AggregateInfo.AggregateCall call = calls.get(i); + switch (call.getFunction()) { + case COUNT: + case COUNT_DISTINCT: + resultRow.setField(i, 0L); + break; + default: + resultRow.setField(i, null); + break; + } } - /** - * Get field index - */ - private int getFieldIndex(String columnName) { - List fieldNames = sourceRowType.getFieldNames(); - return fieldNames.indexOf(columnName); + List results = new ArrayList<>(); + results.add(resultRow); + return results; + } + + /** Get single aggregate function result */ + private Object getAggregateResult( + AggregateState state, int index, AggregateInfo.AggregateCall call) { + switch (call.getFunction()) { + case COUNT: + return state.getCount(index); + case COUNT_DISTINCT: + return (long) state.getDistinctCount(index); + case SUM: + Double sum = state.getSum(index); + return sum != null ? sum : null; + case AVG: + Double avg = state.getAvg(index); + return avg != null ? avg : null; + case MIN: + return state.getMin(index); + case MAX: + return state.getMax(index); + default: + return null; } + } - /** - * Extract field value - */ - private Object extractValue(RowData row, int fieldIndex) { - if (row.isNullAt(fieldIndex)) { - return null; - } + /** Extract group key */ + private GroupKey extractGroupKey(RowData row) { + List groupByCols = aggregateInfo.getGroupByColumns(); + if (groupByCols.isEmpty()) { + return GroupKey.EMPTY; + } - LogicalType fieldType = sourceRowType.getTypeAt(fieldIndex); - switch (fieldType.getTypeRoot()) { - case BOOLEAN: - return row.getBoolean(fieldIndex); - case TINYINT: - return row.getByte(fieldIndex); - case SMALLINT: - return row.getShort(fieldIndex); - case INTEGER: - return row.getInt(fieldIndex); - case BIGINT: - return row.getLong(fieldIndex); - case FLOAT: - return row.getFloat(fieldIndex); - case DOUBLE: - return row.getDouble(fieldIndex); - case CHAR: - case VARCHAR: - // Keep StringData type for group key and result output - return row.getString(fieldIndex); - case DECIMAL: - DecimalType decType = (DecimalType) fieldType; - DecimalData decData = row.getDecimal(fieldIndex, decType.getPrecision(), decType.getScale()); - return decData != null ? 
decData.toBigDecimal() : null; - default: - return null; - } + Object[] keyValues = new Object[groupByCols.size()]; + for (int i = 0; i < groupByCols.size(); i++) { + int fieldIndex = getFieldIndex(groupByCols.get(i)); + if (fieldIndex >= 0) { + keyValues[i] = extractValue(row, fieldIndex); + } + } + return new GroupKey(keyValues); + } + + /** Get field index */ + private int getFieldIndex(String columnName) { + List fieldNames = sourceRowType.getFieldNames(); + return fieldNames.indexOf(columnName); + } + + /** Extract field value */ + private Object extractValue(RowData row, int fieldIndex) { + if (row.isNullAt(fieldIndex)) { + return null; } - /** - * Extract numeric type field value - */ - private Number extractNumericValue(RowData row, int fieldIndex) { - if (row.isNullAt(fieldIndex)) { - return null; - } + LogicalType fieldType = sourceRowType.getTypeAt(fieldIndex); + switch (fieldType.getTypeRoot()) { + case BOOLEAN: + return row.getBoolean(fieldIndex); + case TINYINT: + return row.getByte(fieldIndex); + case SMALLINT: + return row.getShort(fieldIndex); + case INTEGER: + return row.getInt(fieldIndex); + case BIGINT: + return row.getLong(fieldIndex); + case FLOAT: + return row.getFloat(fieldIndex); + case DOUBLE: + return row.getDouble(fieldIndex); + case CHAR: + case VARCHAR: + // Keep StringData type for group key and result output + return row.getString(fieldIndex); + case DECIMAL: + DecimalType decType = (DecimalType) fieldType; + DecimalData decData = + row.getDecimal(fieldIndex, decType.getPrecision(), decType.getScale()); + return decData != null ? decData.toBigDecimal() : null; + default: + return null; + } + } - LogicalType fieldType = sourceRowType.getTypeAt(fieldIndex); - switch (fieldType.getTypeRoot()) { - case TINYINT: - return row.getByte(fieldIndex); - case SMALLINT: - return row.getShort(fieldIndex); - case INTEGER: - return row.getInt(fieldIndex); - case BIGINT: - return row.getLong(fieldIndex); - case FLOAT: - return row.getFloat(fieldIndex); - case DOUBLE: - return row.getDouble(fieldIndex); - case DECIMAL: - DecimalType decType = (DecimalType) fieldType; - DecimalData decData = row.getDecimal(fieldIndex, decType.getPrecision(), decType.getScale()); - return decData != null ? decData.toBigDecimal() : null; - default: - return null; - } + /** Extract numeric type field value */ + private Number extractNumericValue(RowData row, int fieldIndex) { + if (row.isNullAt(fieldIndex)) { + return null; } - /** - * Extract comparable type field value - */ - @SuppressWarnings("unchecked") - private Comparable extractComparableValue(RowData row, int fieldIndex) { - Object value = extractValue(row, fieldIndex); - if (value instanceof Comparable) { - return (Comparable) value; - } + LogicalType fieldType = sourceRowType.getTypeAt(fieldIndex); + switch (fieldType.getTypeRoot()) { + case TINYINT: + return row.getByte(fieldIndex); + case SMALLINT: + return row.getShort(fieldIndex); + case INTEGER: + return row.getInt(fieldIndex); + case BIGINT: + return row.getLong(fieldIndex); + case FLOAT: + return row.getFloat(fieldIndex); + case DOUBLE: + return row.getDouble(fieldIndex); + case DECIMAL: + DecimalType decType = (DecimalType) fieldType; + DecimalData decData = + row.getDecimal(fieldIndex, decType.getPrecision(), decType.getScale()); + return decData != null ? 
decData.toBigDecimal() : null; + default: return null; } + } + + /** Extract comparable type field value */ + @SuppressWarnings("unchecked") + private Comparable extractComparableValue(RowData row, int fieldIndex) { + Object value = extractValue(row, fieldIndex); + if (value instanceof Comparable) { + return (Comparable) value; + } + return null; + } - /** - * Reset aggregate state - */ - public void reset() { - if (aggregateStates != null) { - aggregateStates.clear(); - } + /** Reset aggregate state */ + public void reset() { + if (aggregateStates != null) { + aggregateStates.clear(); } + } - /** - * Group key - */ - private static class GroupKey implements Serializable { - private static final long serialVersionUID = 1L; - - static final GroupKey EMPTY = new GroupKey(new Object[0]); + /** Group key */ + private static class GroupKey implements Serializable { + private static final long serialVersionUID = 1L; - private final Object[] values; - private final int hashCode; + static final GroupKey EMPTY = new GroupKey(new Object[0]); - GroupKey(Object[] values) { - this.values = values; - this.hashCode = Objects.hash((Object[]) values); - } + private final Object[] values; + private final int hashCode; - Object[] getValues() { - return values; - } + GroupKey(Object[] values) { + this.values = values; + this.hashCode = Objects.hash((Object[]) values); + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - GroupKey groupKey = (GroupKey) o; - return java.util.Arrays.equals(values, groupKey.values); - } + Object[] getValues() { + return values; + } - @Override - public int hashCode() { - return hashCode; - } + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + GroupKey groupKey = (GroupKey) o; + return java.util.Arrays.equals(values, groupKey.values); } - /** - * Aggregate state - */ - private static class AggregateState implements Serializable { - private static final long serialVersionUID = 1L; - - private final long[] counts; - private final double[] sums; - private final long[] avgCounts; // For calculating AVG - private final Comparable[] mins; - private final Comparable[] maxs; - private final Set[] distinctSets; - - @SuppressWarnings("unchecked") - AggregateState(int numAggregates) { - this.counts = new long[numAggregates]; - this.sums = new double[numAggregates]; - this.avgCounts = new long[numAggregates]; - this.mins = new Comparable[numAggregates]; - this.maxs = new Comparable[numAggregates]; - this.distinctSets = new Set[numAggregates]; - } + @Override + public int hashCode() { + return hashCode; + } + } - void incrementCount(int index) { - counts[index]++; - } + /** Aggregate state */ + private static class AggregateState implements Serializable { + private static final long serialVersionUID = 1L; - long getCount(int index) { - return counts[index]; - } + private final long[] counts; + private final double[] sums; + private final long[] avgCounts; // For calculating AVG + private final Comparable[] mins; + private final Comparable[] maxs; + private final Set[] distinctSets; - void addDistinctValue(int index, Object value) { - if (distinctSets[index] == null) { - distinctSets[index] = new HashSet<>(); - } - distinctSets[index].add(value); - } + @SuppressWarnings("unchecked") + AggregateState(int numAggregates) { + this.counts = new long[numAggregates]; + this.sums = new double[numAggregates]; + this.avgCounts 
= new long[numAggregates]; + this.mins = new Comparable[numAggregates]; + this.maxs = new Comparable[numAggregates]; + this.distinctSets = new Set[numAggregates]; + } - int getDistinctCount(int index) { - return distinctSets[index] != null ? distinctSets[index].size() : 0; - } + void incrementCount(int index) { + counts[index]++; + } - void addSum(int index, double value) { - sums[index] += value; - counts[index]++; // Mark as has value - } + long getCount(int index) { + return counts[index]; + } - Double getSum(int index) { - return counts[index] > 0 ? sums[index] : null; - } + void addDistinctValue(int index, Object value) { + if (distinctSets[index] == null) { + distinctSets[index] = new HashSet<>(); + } + distinctSets[index].add(value); + } - void addForAvg(int index, double value) { - sums[index] += value; - avgCounts[index]++; - } + int getDistinctCount(int index) { + return distinctSets[index] != null ? distinctSets[index].size() : 0; + } - Double getAvg(int index) { - return avgCounts[index] > 0 ? sums[index] / avgCounts[index] : null; - } + void addSum(int index, double value) { + sums[index] += value; + counts[index]++; // Mark as has value + } - @SuppressWarnings({"unchecked", "rawtypes"}) - void updateMin(int index, Comparable value) { - if (mins[index] == null || ((Comparable) value).compareTo(mins[index]) < 0) { - mins[index] = value; - } - } + Double getSum(int index) { + return counts[index] > 0 ? sums[index] : null; + } - Comparable getMin(int index) { - return mins[index]; - } + void addForAvg(int index, double value) { + sums[index] += value; + avgCounts[index]++; + } - @SuppressWarnings({"unchecked", "rawtypes"}) - void updateMax(int index, Comparable value) { - if (maxs[index] == null || ((Comparable) value).compareTo(maxs[index]) > 0) { - maxs[index] = value; - } - } + Double getAvg(int index) { + return avgCounts[index] > 0 ? sums[index] / avgCounts[index] : null; + } - Comparable getMax(int index) { - return maxs[index]; - } + @SuppressWarnings({"unchecked", "rawtypes"}) + void updateMin(int index, Comparable value) { + if (mins[index] == null || ((Comparable) value).compareTo(mins[index]) < 0) { + mins[index] = value; + } + } + + Comparable getMin(int index) { + return mins[index]; } - /** - * Build aggregate result RowType - */ - public RowType buildResultRowType() { - List groupByCols = aggregateInfo.getGroupByColumns(); - List calls = aggregateInfo.getAggregateCalls(); + @SuppressWarnings({"unchecked", "rawtypes"}) + void updateMax(int index, Comparable value) { + if (maxs[index] == null || ((Comparable) value).compareTo(maxs[index]) > 0) { + maxs[index] = value; + } + } - List fields = new ArrayList<>(); + Comparable getMax(int index) { + return maxs[index]; + } + } - // Group columns - for (String col : groupByCols) { - int fieldIndex = getFieldIndex(col); - if (fieldIndex >= 0) { - LogicalType fieldType = sourceRowType.getTypeAt(fieldIndex); - fields.add(new RowType.RowField(col, fieldType)); - } - } + /** Build aggregate result RowType */ + public RowType buildResultRowType() { + List groupByCols = aggregateInfo.getGroupByColumns(); + List calls = aggregateInfo.getAggregateCalls(); - // Aggregate result columns - for (AggregateInfo.AggregateCall call : calls) { - String alias = call.getAlias() != null ? call.getAlias() : - call.getFunction().name().toLowerCase() + "_" + - (call.getColumn() != null ? 
call.getColumn() : "star"); - LogicalType resultType = getAggregateResultType(call); - fields.add(new RowType.RowField(alias, resultType)); - } + List fields = new ArrayList<>(); - return new RowType(fields); - } - - /** - * Get aggregate function result type - */ - private LogicalType getAggregateResultType(AggregateInfo.AggregateCall call) { - switch (call.getFunction()) { - case COUNT: - case COUNT_DISTINCT: - return new BigIntType(); - case SUM: - case AVG: - return new DoubleType(); - case MIN: - case MAX: - if (call.getColumn() != null) { - int fieldIndex = getFieldIndex(call.getColumn()); - if (fieldIndex >= 0) { - return sourceRowType.getTypeAt(fieldIndex); - } - } - return new DoubleType(); - default: - return new DoubleType(); - } + // Group columns + for (String col : groupByCols) { + int fieldIndex = getFieldIndex(col); + if (fieldIndex >= 0) { + LogicalType fieldType = sourceRowType.getTypeAt(fieldIndex); + fields.add(new RowType.RowField(col, fieldType)); + } + } + + // Aggregate result columns + for (AggregateInfo.AggregateCall call : calls) { + String alias = + call.getAlias() != null + ? call.getAlias() + : call.getFunction().name().toLowerCase() + + "_" + + (call.getColumn() != null ? call.getColumn() : "star"); + LogicalType resultType = getAggregateResultType(call); + fields.add(new RowType.RowField(alias, resultType)); + } + + return new RowType(fields); + } + + /** Get aggregate function result type */ + private LogicalType getAggregateResultType(AggregateInfo.AggregateCall call) { + switch (call.getFunction()) { + case COUNT: + case COUNT_DISTINCT: + return new BigIntType(); + case SUM: + case AVG: + return new DoubleType(); + case MIN: + case MAX: + if (call.getColumn() != null) { + int fieldIndex = getFieldIndex(call.getColumn()); + if (fieldIndex >= 0) { + return sourceRowType.getTypeAt(fieldIndex); + } + } + return new DoubleType(); + default: + return new DoubleType(); } + } } diff --git a/src/main/java/org/apache/flink/connector/lance/aggregate/AggregateInfo.java b/src/main/java/org/apache/flink/connector/lance/aggregate/AggregateInfo.java index 64a1312..f7d2c02 100644 --- a/src/main/java/org/apache/flink/connector/lance/aggregate/AggregateInfo.java +++ b/src/main/java/org/apache/flink/connector/lance/aggregate/AggregateInfo.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.flink.connector.lance.aggregate; import java.io.Serializable; @@ -26,233 +21,218 @@ /** * Aggregate information encapsulation class. - * - *

Encapsulates information needed for aggregate push-down, including aggregate functions, target columns and group by columns. + * + *
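+ * <p>For illustration only (not part of this commit), a minimal sketch of building an
+ * AggregateInfo equivalent to "SELECT category, COUNT(*), SUM(amount) ... GROUP BY category";
+ * the column names are placeholders:
+ *
+ * <pre>{@code
+ * AggregateInfo info =
+ *     AggregateInfo.builder()
+ *         .addCountStar("cnt")
+ *         .addSum("amount", "total")
+ *         .groupBy("category")
+ *         .build();
+ * List<String> required = info.getRequiredColumns(); // [category, amount]
+ * }</pre>
+ *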

Encapsulates information needed for aggregate push-down, including aggregate functions, target + * columns and group by columns. */ public class AggregateInfo implements Serializable { + private static final long serialVersionUID = 1L; + + /** Supported aggregate function types */ + public enum AggregateFunction { + /** COUNT(*) or COUNT(column) */ + COUNT, + /** COUNT(DISTINCT column) */ + COUNT_DISTINCT, + /** SUM(column) */ + SUM, + /** AVG(column) */ + AVG, + /** MIN(column) */ + MIN, + /** MAX(column) */ + MAX + } + + /** Single aggregate call information */ + public static class AggregateCall implements Serializable { private static final long serialVersionUID = 1L; - /** - * Supported aggregate function types - */ - public enum AggregateFunction { - /** COUNT(*) or COUNT(column) */ - COUNT, - /** COUNT(DISTINCT column) */ - COUNT_DISTINCT, - /** SUM(column) */ - SUM, - /** AVG(column) */ - AVG, - /** MIN(column) */ - MIN, - /** MAX(column) */ - MAX + private final AggregateFunction function; + private final String column; // null means COUNT(*) + private final String alias; // alias for aggregate result + + public AggregateCall(AggregateFunction function, String column, String alias) { + this.function = function; + this.column = column; + this.alias = alias; } - /** - * Single aggregate call information - */ - public static class AggregateCall implements Serializable { - private static final long serialVersionUID = 1L; - - private final AggregateFunction function; - private final String column; // null means COUNT(*) - private final String alias; // alias for aggregate result - - public AggregateCall(AggregateFunction function, String column, String alias) { - this.function = function; - this.column = column; - this.alias = alias; - } - - public AggregateFunction getFunction() { - return function; - } - - public String getColumn() { - return column; - } - - public String getAlias() { - return alias; - } - - /** - * Whether is COUNT(*) - */ - public boolean isCountStar() { - return function == AggregateFunction.COUNT && column == null; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - AggregateCall that = (AggregateCall) o; - return function == that.function && - Objects.equals(column, that.column) && - Objects.equals(alias, that.alias); - } - - @Override - public int hashCode() { - return Objects.hash(function, column, alias); - } - - @Override - public String toString() { - if (isCountStar()) { - return "COUNT(*)"; - } - return function.name() + "(" + column + ")"; - } + public AggregateFunction getFunction() { + return function; } - private final List aggregateCalls; - private final List groupByColumns; - private final int[] groupByFieldIndices; + public String getColumn() { + return column; + } - private AggregateInfo(Builder builder) { - this.aggregateCalls = Collections.unmodifiableList(new ArrayList<>(builder.aggregateCalls)); - this.groupByColumns = Collections.unmodifiableList(new ArrayList<>(builder.groupByColumns)); - this.groupByFieldIndices = builder.groupByFieldIndices != null ? 
- builder.groupByFieldIndices.clone() : new int[0]; + public String getAlias() { + return alias; } - public List getAggregateCalls() { - return aggregateCalls; + /** Whether is COUNT(*) */ + public boolean isCountStar() { + return function == AggregateFunction.COUNT && column == null; } - public List getGroupByColumns() { - return groupByColumns; + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AggregateCall that = (AggregateCall) o; + return function == that.function + && Objects.equals(column, that.column) + && Objects.equals(alias, that.alias); } - public int[] getGroupByFieldIndices() { - return groupByFieldIndices; + @Override + public int hashCode() { + return Objects.hash(function, column, alias); } - /** - * Whether has group by - */ - public boolean hasGroupBy() { - return !groupByColumns.isEmpty(); + @Override + public String toString() { + if (isCountStar()) { + return "COUNT(*)"; + } + return function.name() + "(" + column + ")"; + } + } + + private final List aggregateCalls; + private final List groupByColumns; + private final int[] groupByFieldIndices; + + private AggregateInfo(Builder builder) { + this.aggregateCalls = Collections.unmodifiableList(new ArrayList<>(builder.aggregateCalls)); + this.groupByColumns = Collections.unmodifiableList(new ArrayList<>(builder.groupByColumns)); + this.groupByFieldIndices = + builder.groupByFieldIndices != null ? builder.groupByFieldIndices.clone() : new int[0]; + } + + public List getAggregateCalls() { + return aggregateCalls; + } + + public List getGroupByColumns() { + return groupByColumns; + } + + public int[] getGroupByFieldIndices() { + return groupByFieldIndices; + } + + /** Whether has group by */ + public boolean hasGroupBy() { + return !groupByColumns.isEmpty(); + } + + /** Whether is simple COUNT(*) query (no group by) */ + public boolean isSimpleCountStar() { + return aggregateCalls.size() == 1 && aggregateCalls.get(0).isCountStar() && !hasGroupBy(); + } + + /** Get all required columns (aggregate columns + group by columns) */ + public List getRequiredColumns() { + List columns = new ArrayList<>(groupByColumns); + for (AggregateCall call : aggregateCalls) { + if (call.getColumn() != null && !columns.contains(call.getColumn())) { + columns.add(call.getColumn()); + } + } + return columns; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AggregateInfo that = (AggregateInfo) o; + return Objects.equals(aggregateCalls, that.aggregateCalls) + && Objects.equals(groupByColumns, that.groupByColumns); + } + + @Override + public int hashCode() { + return Objects.hash(aggregateCalls, groupByColumns); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("AggregateInfo{"); + sb.append("aggregates=").append(aggregateCalls); + if (hasGroupBy()) { + sb.append(", groupBy=").append(groupByColumns); + } + sb.append("}"); + return sb.toString(); + } + + public static Builder builder() { + return new Builder(); + } + + /** AggregateInfo builder */ + public static class Builder { + private final List aggregateCalls = new ArrayList<>(); + private final List groupByColumns = new ArrayList<>(); + private int[] groupByFieldIndices; + + public Builder addAggregateCall(AggregateFunction function, String column, String alias) { + aggregateCalls.add(new AggregateCall(function, column, alias)); + return this; } - /** - * Whether is simple 
COUNT(*) query (no group by) - */ - public boolean isSimpleCountStar() { - return aggregateCalls.size() == 1 && - aggregateCalls.get(0).isCountStar() && - !hasGroupBy(); + public Builder addAggregateCall(AggregateCall call) { + aggregateCalls.add(call); + return this; } - /** - * Get all required columns (aggregate columns + group by columns) - */ - public List getRequiredColumns() { - List columns = new ArrayList<>(groupByColumns); - for (AggregateCall call : aggregateCalls) { - if (call.getColumn() != null && !columns.contains(call.getColumn())) { - columns.add(call.getColumn()); - } - } - return columns; + public Builder addCountStar(String alias) { + return addAggregateCall(AggregateFunction.COUNT, null, alias); } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - AggregateInfo that = (AggregateInfo) o; - return Objects.equals(aggregateCalls, that.aggregateCalls) && - Objects.equals(groupByColumns, that.groupByColumns); + public Builder addCount(String column, String alias) { + return addAggregateCall(AggregateFunction.COUNT, column, alias); } - @Override - public int hashCode() { - return Objects.hash(aggregateCalls, groupByColumns); + public Builder addSum(String column, String alias) { + return addAggregateCall(AggregateFunction.SUM, column, alias); } - @Override - public String toString() { - StringBuilder sb = new StringBuilder("AggregateInfo{"); - sb.append("aggregates=").append(aggregateCalls); - if (hasGroupBy()) { - sb.append(", groupBy=").append(groupByColumns); - } - sb.append("}"); - return sb.toString(); + public Builder addAvg(String column, String alias) { + return addAggregateCall(AggregateFunction.AVG, column, alias); + } + + public Builder addMin(String column, String alias) { + return addAggregateCall(AggregateFunction.MIN, column, alias); + } + + public Builder addMax(String column, String alias) { + return addAggregateCall(AggregateFunction.MAX, column, alias); + } + + public Builder groupBy(List columns) { + this.groupByColumns.addAll(columns); + return this; + } + + public Builder groupBy(String... 
columns) { + Collections.addAll(this.groupByColumns, columns); + return this; } - public static Builder builder() { - return new Builder(); + public Builder groupByFieldIndices(int[] indices) { + this.groupByFieldIndices = indices; + return this; } - /** - * AggregateInfo builder - */ - public static class Builder { - private final List aggregateCalls = new ArrayList<>(); - private final List groupByColumns = new ArrayList<>(); - private int[] groupByFieldIndices; - - public Builder addAggregateCall(AggregateFunction function, String column, String alias) { - aggregateCalls.add(new AggregateCall(function, column, alias)); - return this; - } - - public Builder addAggregateCall(AggregateCall call) { - aggregateCalls.add(call); - return this; - } - - public Builder addCountStar(String alias) { - return addAggregateCall(AggregateFunction.COUNT, null, alias); - } - - public Builder addCount(String column, String alias) { - return addAggregateCall(AggregateFunction.COUNT, column, alias); - } - - public Builder addSum(String column, String alias) { - return addAggregateCall(AggregateFunction.SUM, column, alias); - } - - public Builder addAvg(String column, String alias) { - return addAggregateCall(AggregateFunction.AVG, column, alias); - } - - public Builder addMin(String column, String alias) { - return addAggregateCall(AggregateFunction.MIN, column, alias); - } - - public Builder addMax(String column, String alias) { - return addAggregateCall(AggregateFunction.MAX, column, alias); - } - - public Builder groupBy(List columns) { - this.groupByColumns.addAll(columns); - return this; - } - - public Builder groupBy(String... columns) { - Collections.addAll(this.groupByColumns, columns); - return this; - } - - public Builder groupByFieldIndices(int[] indices) { - this.groupByFieldIndices = indices; - return this; - } - - public AggregateInfo build() { - if (aggregateCalls.isEmpty()) { - throw new IllegalArgumentException("At least one aggregate function is required"); - } - return new AggregateInfo(this); - } + public AggregateInfo build() { + if (aggregateCalls.isEmpty()) { + throw new IllegalArgumentException("At least one aggregate function is required"); + } + return new AggregateInfo(this); } + } } diff --git a/src/main/java/org/apache/flink/connector/lance/config/LanceOptions.java b/src/main/java/org/apache/flink/connector/lance/config/LanceOptions.java index 5b20fa5..2147408 100644 --- a/src/main/java/org/apache/flink/connector/lance/config/LanceOptions.java +++ b/src/main/java/org/apache/flink/connector/lance/config/LanceOptions.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.flink.connector.lance.config; import org.apache.flink.configuration.ConfigOption; @@ -30,820 +25,826 @@ /** * Lance connector configuration options. 
- * + * *
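+ * <p>For illustration only (not part of this commit), a minimal sketch of building options
+ * programmatically with the builder defined below; the dataset path is a placeholder:
+ *
+ * <pre>{@code
+ * LanceOptions options =
+ *     LanceOptions.builder()
+ *         .path("/tmp/demo.lance")
+ *         .readBatchSize(2048)
+ *         .writeMode(LanceOptions.WriteMode.OVERWRITE)
+ *         .build();
+ * }</pre>
+ *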

Defines all configuration items for Source, Sink, vector index and vector search. */ public class LanceOptions implements Serializable { - private static final long serialVersionUID = 1L; - - // ==================== Common Configuration ==================== - - /** - * Lance dataset path - */ - public static final ConfigOption PATH = ConfigOptions - .key("path") - .stringType() - .noDefaultValue() - .withDescription("Path to Lance dataset (required)"); - - // ==================== Source Configuration ==================== - - /** - * Read batch size - */ - public static final ConfigOption READ_BATCH_SIZE = ConfigOptions - .key("read.batch-size") - .intType() - .defaultValue(1024) - .withDescription("Batch size for reading, default 1024"); - - /** - * Read row limit (Limit push-down) - */ - public static final ConfigOption READ_LIMIT = ConfigOptions - .key("read.limit") - .longType() - .noDefaultValue() - .withDescription("Maximum number of rows to read (for Limit push-down)"); - - /** - * List of columns to read (comma separated) - */ - public static final ConfigOption READ_COLUMNS = ConfigOptions - .key("read.columns") - .stringType() - .noDefaultValue() - .withDescription("List of columns to read, comma separated. Empty reads all columns"); - - /** - * Data filter condition - */ - public static final ConfigOption READ_FILTER = ConfigOptions - .key("read.filter") - .stringType() - .noDefaultValue() - .withDescription("Data filter condition, using SQL WHERE clause syntax"); - - // ==================== Sink Configuration ==================== - - /** - * Write batch size - */ - public static final ConfigOption WRITE_BATCH_SIZE = ConfigOptions - .key("write.batch-size") - .intType() - .defaultValue(1024) - .withDescription("Batch size for writing, default 1024"); - - /** - * Write mode: append or overwrite - */ - public static final ConfigOption WRITE_MODE = ConfigOptions - .key("write.mode") - .stringType() - .defaultValue("append") - .withDescription("Write mode: append or overwrite, default append"); - - /** - * Maximum rows per file - */ - public static final ConfigOption WRITE_MAX_ROWS_PER_FILE = ConfigOptions - .key("write.max-rows-per-file") - .intType() - .defaultValue(1000000) - .withDescription("Maximum rows per data file, default 1000000"); - - // ==================== Vector Index Configuration ==================== - - /** - * Index type: IVF_PQ, IVF_HNSW, IVF_FLAT - */ - public static final ConfigOption INDEX_TYPE = ConfigOptions - .key("index.type") - .stringType() - .defaultValue("IVF_PQ") - .withDescription("Vector index type: IVF_PQ, IVF_HNSW, IVF_FLAT, default IVF_PQ"); - - /** - * Index column name - */ - public static final ConfigOption INDEX_COLUMN = ConfigOptions - .key("index.column") - .stringType() - .noDefaultValue() - .withDescription("Vector column name for indexing (required)"); - - /** - * IVF partition count - */ - public static final ConfigOption INDEX_NUM_PARTITIONS = ConfigOptions - .key("index.num-partitions") - .intType() - .defaultValue(256) - .withDescription("Number of IVF index partitions, default 256"); - - /** - * PQ sub-vector count - */ - public static final ConfigOption INDEX_NUM_SUB_VECTORS = ConfigOptions - .key("index.num-sub-vectors") - .intType() - .noDefaultValue() - .withDescription("Number of PQ index sub-vectors, default auto-calculated"); - - /** - * PQ quantization bits - */ - public static final ConfigOption INDEX_NUM_BITS = ConfigOptions - .key("index.num-bits") - .intType() - .defaultValue(8) - .withDescription("PQ quantization bits, 
default 8"); - - /** - * HNSW max level - */ - public static final ConfigOption INDEX_MAX_LEVEL = ConfigOptions - .key("index.max-level") - .intType() - .defaultValue(7) - .withDescription("HNSW index max level, default 7"); - - /** - * HNSW connections per level M - */ - public static final ConfigOption INDEX_M = ConfigOptions - .key("index.m") - .intType() - .defaultValue(16) - .withDescription("HNSW connections per level M, default 16"); - - /** - * HNSW construction search width - */ - public static final ConfigOption INDEX_EF_CONSTRUCTION = ConfigOptions - .key("index.ef-construction") - .intType() - .defaultValue(100) - .withDescription("HNSW construction search width ef_construction, default 100"); - - // ==================== Vector Search Configuration ==================== - - /** - * Vector search column name - */ - public static final ConfigOption VECTOR_COLUMN = ConfigOptions - .key("vector.column") - .stringType() - .noDefaultValue() - .withDescription("Vector search column name (required)"); - - /** - * Distance metric type: L2, Cosine, Dot - */ - public static final ConfigOption VECTOR_METRIC = ConfigOptions - .key("vector.metric") - .stringType() - .defaultValue("L2") - .withDescription("Vector distance metric type: L2 (Euclidean), Cosine, Dot, default L2"); - - /** - * IVF search probe count - */ - public static final ConfigOption VECTOR_NPROBES = ConfigOptions - .key("vector.nprobes") - .intType() - .defaultValue(20) - .withDescription("Number of IVF index search probes, default 20"); - - /** - * HNSW search width - */ - public static final ConfigOption VECTOR_EF = ConfigOptions - .key("vector.ef") - .intType() - .defaultValue(100) - .withDescription("HNSW search width ef, default 100"); - - /** - * Refine factor - */ - public static final ConfigOption VECTOR_REFINE_FACTOR = ConfigOptions - .key("vector.refine-factor") - .intType() - .noDefaultValue() - .withDescription("Vector search refine factor for improving recall"); - - // ==================== Catalog Configuration ==================== - - /** - * Default database name - */ - public static final ConfigOption DEFAULT_DATABASE = ConfigOptions - .key("default-database") - .stringType() - .defaultValue("default") - .withDescription("Catalog default database name, default 'default'"); - - /** - * Warehouse path - */ - public static final ConfigOption WAREHOUSE = ConfigOptions - .key("warehouse") - .stringType() - .noDefaultValue() - .withDescription("Lance data warehouse path (required)"); - - // ==================== Write Mode Enum ==================== - - /** - * Write mode enum - */ - public enum WriteMode { - APPEND("append"), - OVERWRITE("overwrite"); - - private final String value; - - WriteMode(String value) { - this.value = value; - } - - public String getValue() { - return value; - } + private static final long serialVersionUID = 1L; + + // ==================== Common Configuration ==================== + + /** Lance dataset path */ + public static final ConfigOption PATH = + ConfigOptions.key("path") + .stringType() + .noDefaultValue() + .withDescription("Path to Lance dataset (required)"); + + // ==================== Source Configuration ==================== + + /** Read batch size */ + public static final ConfigOption READ_BATCH_SIZE = + ConfigOptions.key("read.batch-size") + .intType() + .defaultValue(1024) + .withDescription("Batch size for reading, default 1024"); + + /** Read row limit (Limit push-down) */ + public static final ConfigOption READ_LIMIT = + ConfigOptions.key("read.limit") + .longType() + 
.noDefaultValue() + .withDescription("Maximum number of rows to read (for Limit push-down)"); + + /** List of columns to read (comma separated) */ + public static final ConfigOption READ_COLUMNS = + ConfigOptions.key("read.columns") + .stringType() + .noDefaultValue() + .withDescription("List of columns to read, comma separated. Empty reads all columns"); + + /** Data filter condition */ + public static final ConfigOption READ_FILTER = + ConfigOptions.key("read.filter") + .stringType() + .noDefaultValue() + .withDescription("Data filter condition, using SQL WHERE clause syntax"); + + // ==================== Sink Configuration ==================== + + /** Write batch size */ + public static final ConfigOption WRITE_BATCH_SIZE = + ConfigOptions.key("write.batch-size") + .intType() + .defaultValue(1024) + .withDescription("Batch size for writing, default 1024"); + + /** Write mode: append or overwrite */ + public static final ConfigOption WRITE_MODE = + ConfigOptions.key("write.mode") + .stringType() + .defaultValue("append") + .withDescription("Write mode: append or overwrite, default append"); + + /** Maximum rows per file */ + public static final ConfigOption WRITE_MAX_ROWS_PER_FILE = + ConfigOptions.key("write.max-rows-per-file") + .intType() + .defaultValue(1000000) + .withDescription("Maximum rows per data file, default 1000000"); + + // ==================== Vector Index Configuration ==================== + + /** Index type: IVF_PQ, IVF_HNSW, IVF_FLAT */ + public static final ConfigOption INDEX_TYPE = + ConfigOptions.key("index.type") + .stringType() + .defaultValue("IVF_PQ") + .withDescription("Vector index type: IVF_PQ, IVF_HNSW, IVF_FLAT, default IVF_PQ"); + + /** Index column name */ + public static final ConfigOption INDEX_COLUMN = + ConfigOptions.key("index.column") + .stringType() + .noDefaultValue() + .withDescription("Vector column name for indexing (required)"); + + /** IVF partition count */ + public static final ConfigOption INDEX_NUM_PARTITIONS = + ConfigOptions.key("index.num-partitions") + .intType() + .defaultValue(256) + .withDescription("Number of IVF index partitions, default 256"); + + /** PQ sub-vector count */ + public static final ConfigOption INDEX_NUM_SUB_VECTORS = + ConfigOptions.key("index.num-sub-vectors") + .intType() + .noDefaultValue() + .withDescription("Number of PQ index sub-vectors, default auto-calculated"); + + /** PQ quantization bits */ + public static final ConfigOption INDEX_NUM_BITS = + ConfigOptions.key("index.num-bits") + .intType() + .defaultValue(8) + .withDescription("PQ quantization bits, default 8"); + + /** HNSW max level */ + public static final ConfigOption INDEX_MAX_LEVEL = + ConfigOptions.key("index.max-level") + .intType() + .defaultValue(7) + .withDescription("HNSW index max level, default 7"); + + /** HNSW connections per level M */ + public static final ConfigOption INDEX_M = + ConfigOptions.key("index.m") + .intType() + .defaultValue(16) + .withDescription("HNSW connections per level M, default 16"); + + /** HNSW construction search width */ + public static final ConfigOption INDEX_EF_CONSTRUCTION = + ConfigOptions.key("index.ef-construction") + .intType() + .defaultValue(100) + .withDescription("HNSW construction search width ef_construction, default 100"); + + // ==================== Vector Search Configuration ==================== + + /** Vector search column name */ + public static final ConfigOption VECTOR_COLUMN = + ConfigOptions.key("vector.column") + .stringType() + .noDefaultValue() + .withDescription("Vector search 
column name (required)"); + + /** Distance metric type: L2, Cosine, Dot */ + public static final ConfigOption VECTOR_METRIC = + ConfigOptions.key("vector.metric") + .stringType() + .defaultValue("L2") + .withDescription("Vector distance metric type: L2 (Euclidean), Cosine, Dot, default L2"); + + /** IVF search probe count */ + public static final ConfigOption VECTOR_NPROBES = + ConfigOptions.key("vector.nprobes") + .intType() + .defaultValue(20) + .withDescription("Number of IVF index search probes, default 20"); + + /** HNSW search width */ + public static final ConfigOption VECTOR_EF = + ConfigOptions.key("vector.ef") + .intType() + .defaultValue(100) + .withDescription("HNSW search width ef, default 100"); + + /** Refine factor */ + public static final ConfigOption VECTOR_REFINE_FACTOR = + ConfigOptions.key("vector.refine-factor") + .intType() + .noDefaultValue() + .withDescription("Vector search refine factor for improving recall"); + + // ==================== Catalog Configuration ==================== + + /** Default database name */ + public static final ConfigOption DEFAULT_DATABASE = + ConfigOptions.key("default-database") + .stringType() + .defaultValue("default") + .withDescription("Catalog default database name, default 'default'"); + + /** Warehouse path */ + public static final ConfigOption WAREHOUSE = + ConfigOptions.key("warehouse") + .stringType() + .noDefaultValue() + .withDescription("Lance data warehouse path (required)"); + + // ==================== Write Mode Enum ==================== + + /** Write mode enum */ + public enum WriteMode { + APPEND("append"), + OVERWRITE("overwrite"); + + private final String value; + + WriteMode(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + public static WriteMode fromValue(String value) { + for (WriteMode mode : values()) { + if (mode.value.equalsIgnoreCase(value)) { + return mode; + } + } + throw new IllegalArgumentException( + "Unsupported write mode: " + value + ", supported modes: append, overwrite"); + } + } + + // ==================== Index Type Enum ==================== + + /** Index type enum */ + public enum IndexType { + IVF_PQ("IVF_PQ"), + IVF_HNSW("IVF_HNSW"), + IVF_FLAT("IVF_FLAT"); + + private final String value; + + IndexType(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + public static IndexType fromValue(String value) { + for (IndexType type : values()) { + if (type.value.equalsIgnoreCase(value)) { + return type; + } + } + throw new IllegalArgumentException( + "Unsupported index type: " + value + ", supported types: IVF_PQ, IVF_HNSW, IVF_FLAT"); + } + } + + // ==================== Metric Type Enum ==================== + + /** Distance metric type enum */ + public enum MetricType { + L2("L2"), + COSINE("Cosine"), + DOT("Dot"); + + private final String value; + + MetricType(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + public static MetricType fromValue(String value) { + for (MetricType type : values()) { + if (type.value.equalsIgnoreCase(value)) { + return type; + } + } + throw new IllegalArgumentException( + "Unsupported metric type: " + value + ", supported types: L2, Cosine, Dot"); + } + } + + // ==================== Configuration Class ==================== + + private final String path; + private final int readBatchSize; + private final Long readLimit; + private final List readColumns; + private final String readFilter; + private final int writeBatchSize; + private 
final WriteMode writeMode; + private final int writeMaxRowsPerFile; + private final IndexType indexType; + private final String indexColumn; + private final int indexNumPartitions; + private final Integer indexNumSubVectors; + private final int indexNumBits; + private final int indexMaxLevel; + private final int indexM; + private final int indexEfConstruction; + private final String vectorColumn; + private final MetricType vectorMetric; + private final int vectorNprobes; + private final int vectorEf; + private final Integer vectorRefineFactor; + private final String defaultDatabase; + private final String warehouse; + + private LanceOptions(Builder builder) { + this.path = builder.path; + this.readBatchSize = builder.readBatchSize; + this.readLimit = builder.readLimit; + this.readColumns = builder.readColumns; + this.readFilter = builder.readFilter; + this.writeBatchSize = builder.writeBatchSize; + this.writeMode = builder.writeMode; + this.writeMaxRowsPerFile = builder.writeMaxRowsPerFile; + this.indexType = builder.indexType; + this.indexColumn = builder.indexColumn; + this.indexNumPartitions = builder.indexNumPartitions; + this.indexNumSubVectors = builder.indexNumSubVectors; + this.indexNumBits = builder.indexNumBits; + this.indexMaxLevel = builder.indexMaxLevel; + this.indexM = builder.indexM; + this.indexEfConstruction = builder.indexEfConstruction; + this.vectorColumn = builder.vectorColumn; + this.vectorMetric = builder.vectorMetric; + this.vectorNprobes = builder.vectorNprobes; + this.vectorEf = builder.vectorEf; + this.vectorRefineFactor = builder.vectorRefineFactor; + this.defaultDatabase = builder.defaultDatabase; + this.warehouse = builder.warehouse; + } + + // ==================== Getter Methods ==================== + + public String getPath() { + return path; + } + + public int getReadBatchSize() { + return readBatchSize; + } + + public Long getReadLimit() { + return readLimit; + } + + public List getReadColumns() { + return readColumns; + } + + public String getReadFilter() { + return readFilter; + } + + public int getWriteBatchSize() { + return writeBatchSize; + } + + public WriteMode getWriteMode() { + return writeMode; + } + + public int getWriteMaxRowsPerFile() { + return writeMaxRowsPerFile; + } + + public IndexType getIndexType() { + return indexType; + } + + public String getIndexColumn() { + return indexColumn; + } + + public int getIndexNumPartitions() { + return indexNumPartitions; + } + + public Integer getIndexNumSubVectors() { + return indexNumSubVectors; + } + + public int getIndexNumBits() { + return indexNumBits; + } + + public int getIndexMaxLevel() { + return indexMaxLevel; + } + + public int getIndexM() { + return indexM; + } + + public int getIndexEfConstruction() { + return indexEfConstruction; + } + + public String getVectorColumn() { + return vectorColumn; + } + + public MetricType getVectorMetric() { + return vectorMetric; + } + + public int getVectorNprobes() { + return vectorNprobes; + } + + public int getVectorEf() { + return vectorEf; + } + + public Integer getVectorRefineFactor() { + return vectorRefineFactor; + } + + public String getDefaultDatabase() { + return defaultDatabase; + } + + public String getWarehouse() { + return warehouse; + } + + // ==================== Builder ==================== + + public static Builder builder() { + return new Builder(); + } + + /** Create LanceOptions from Flink Configuration */ + public static LanceOptions fromConfiguration(Configuration config) { + Builder builder = builder(); + + // Common configuration 
+ if (config.contains(PATH)) { + builder.path(config.get(PATH)); + } - public static WriteMode fromValue(String value) { - for (WriteMode mode : values()) { - if (mode.value.equalsIgnoreCase(value)) { - return mode; - } - } - throw new IllegalArgumentException("Unsupported write mode: " + value + ", supported modes: append, overwrite"); - } + // Source configuration + builder.readBatchSize(config.get(READ_BATCH_SIZE)); + if (config.contains(READ_LIMIT)) { + builder.readLimit(config.get(READ_LIMIT)); } - - // ==================== Index Type Enum ==================== - - /** - * Index type enum - */ - public enum IndexType { - IVF_PQ("IVF_PQ"), - IVF_HNSW("IVF_HNSW"), - IVF_FLAT("IVF_FLAT"); - - private final String value; - - IndexType(String value) { - this.value = value; - } - - public String getValue() { - return value; - } - - public static IndexType fromValue(String value) { - for (IndexType type : values()) { - if (type.value.equalsIgnoreCase(value)) { - return type; - } - } - throw new IllegalArgumentException("Unsupported index type: " + value + ", supported types: IVF_PQ, IVF_HNSW, IVF_FLAT"); - } + if (config.contains(READ_COLUMNS)) { + String columnsStr = config.get(READ_COLUMNS); + if (columnsStr != null && !columnsStr.isEmpty()) { + builder.readColumns(Arrays.asList(columnsStr.split(","))); + } + } + if (config.contains(READ_FILTER)) { + builder.readFilter(config.get(READ_FILTER)); } - // ==================== Metric Type Enum ==================== - - /** - * Distance metric type enum - */ - public enum MetricType { - L2("L2"), - COSINE("Cosine"), - DOT("Dot"); - - private final String value; - - MetricType(String value) { - this.value = value; - } - - public String getValue() { - return value; - } + // Sink configuration + builder.writeBatchSize(config.get(WRITE_BATCH_SIZE)); + builder.writeMode(WriteMode.fromValue(config.get(WRITE_MODE))); + builder.writeMaxRowsPerFile(config.get(WRITE_MAX_ROWS_PER_FILE)); - public static MetricType fromValue(String value) { - for (MetricType type : values()) { - if (type.value.equalsIgnoreCase(value)) { - return type; - } - } - throw new IllegalArgumentException("Unsupported metric type: " + value + ", supported types: L2, Cosine, Dot"); - } + // Index configuration + builder.indexType(IndexType.fromValue(config.get(INDEX_TYPE))); + if (config.contains(INDEX_COLUMN)) { + builder.indexColumn(config.get(INDEX_COLUMN)); } + builder.indexNumPartitions(config.get(INDEX_NUM_PARTITIONS)); + if (config.contains(INDEX_NUM_SUB_VECTORS)) { + builder.indexNumSubVectors(config.get(INDEX_NUM_SUB_VECTORS)); + } + builder.indexNumBits(config.get(INDEX_NUM_BITS)); + builder.indexMaxLevel(config.get(INDEX_MAX_LEVEL)); + builder.indexM(config.get(INDEX_M)); + builder.indexEfConstruction(config.get(INDEX_EF_CONSTRUCTION)); - // ==================== Configuration Class ==================== - - private final String path; - private final int readBatchSize; - private final Long readLimit; - private final List readColumns; - private final String readFilter; - private final int writeBatchSize; - private final WriteMode writeMode; - private final int writeMaxRowsPerFile; - private final IndexType indexType; - private final String indexColumn; - private final int indexNumPartitions; - private final Integer indexNumSubVectors; - private final int indexNumBits; - private final int indexMaxLevel; - private final int indexM; - private final int indexEfConstruction; - private final String vectorColumn; - private final MetricType vectorMetric; - private final int 
vectorNprobes; - private final int vectorEf; - private final Integer vectorRefineFactor; - private final String defaultDatabase; - private final String warehouse; + // Vector search configuration + if (config.contains(VECTOR_COLUMN)) { + builder.vectorColumn(config.get(VECTOR_COLUMN)); + } + builder.vectorMetric(MetricType.fromValue(config.get(VECTOR_METRIC))); + builder.vectorNprobes(config.get(VECTOR_NPROBES)); + builder.vectorEf(config.get(VECTOR_EF)); + if (config.contains(VECTOR_REFINE_FACTOR)) { + builder.vectorRefineFactor(config.get(VECTOR_REFINE_FACTOR)); + } - private LanceOptions(Builder builder) { - this.path = builder.path; - this.readBatchSize = builder.readBatchSize; - this.readLimit = builder.readLimit; - this.readColumns = builder.readColumns; - this.readFilter = builder.readFilter; - this.writeBatchSize = builder.writeBatchSize; - this.writeMode = builder.writeMode; - this.writeMaxRowsPerFile = builder.writeMaxRowsPerFile; - this.indexType = builder.indexType; - this.indexColumn = builder.indexColumn; - this.indexNumPartitions = builder.indexNumPartitions; - this.indexNumSubVectors = builder.indexNumSubVectors; - this.indexNumBits = builder.indexNumBits; - this.indexMaxLevel = builder.indexMaxLevel; - this.indexM = builder.indexM; - this.indexEfConstruction = builder.indexEfConstruction; - this.vectorColumn = builder.vectorColumn; - this.vectorMetric = builder.vectorMetric; - this.vectorNprobes = builder.vectorNprobes; - this.vectorEf = builder.vectorEf; - this.vectorRefineFactor = builder.vectorRefineFactor; - this.defaultDatabase = builder.defaultDatabase; - this.warehouse = builder.warehouse; + // Catalog configuration + builder.defaultDatabase(config.get(DEFAULT_DATABASE)); + if (config.contains(WAREHOUSE)) { + builder.warehouse(config.get(WAREHOUSE)); } - // ==================== Getter Methods ==================== + return builder.build(); + } - public String getPath() { - return path; - } + /** Configuration builder */ + public static class Builder { + private String path; + private int readBatchSize = 1024; + private Long readLimit; + private List readColumns = Collections.emptyList(); + private String readFilter; + private int writeBatchSize = 1024; + private WriteMode writeMode = WriteMode.APPEND; + private int writeMaxRowsPerFile = 1000000; + private IndexType indexType = IndexType.IVF_PQ; + private String indexColumn; + private int indexNumPartitions = 256; + private Integer indexNumSubVectors; + private int indexNumBits = 8; + private int indexMaxLevel = 7; + private int indexM = 16; + private int indexEfConstruction = 100; + private String vectorColumn; + private MetricType vectorMetric = MetricType.L2; + private int vectorNprobes = 20; + private int vectorEf = 100; + private Integer vectorRefineFactor; + private String defaultDatabase = "default"; + private String warehouse; - public int getReadBatchSize() { - return readBatchSize; + public Builder path(String path) { + this.path = path; + return this; } - public Long getReadLimit() { - return readLimit; + public Builder readBatchSize(int readBatchSize) { + this.readBatchSize = readBatchSize; + return this; } - public List getReadColumns() { - return readColumns; + public Builder readLimit(Long readLimit) { + this.readLimit = readLimit; + return this; } - public String getReadFilter() { - return readFilter; + public Builder readColumns(List readColumns) { + this.readColumns = readColumns != null ? 
readColumns : Collections.emptyList(); + return this; } - public int getWriteBatchSize() { - return writeBatchSize; + public Builder readFilter(String readFilter) { + this.readFilter = readFilter; + return this; } - public WriteMode getWriteMode() { - return writeMode; + public Builder writeBatchSize(int writeBatchSize) { + this.writeBatchSize = writeBatchSize; + return this; } - public int getWriteMaxRowsPerFile() { - return writeMaxRowsPerFile; + public Builder writeMode(WriteMode writeMode) { + this.writeMode = writeMode; + return this; } - public IndexType getIndexType() { - return indexType; + public Builder writeMaxRowsPerFile(int writeMaxRowsPerFile) { + this.writeMaxRowsPerFile = writeMaxRowsPerFile; + return this; } - public String getIndexColumn() { - return indexColumn; + public Builder indexType(IndexType indexType) { + this.indexType = indexType; + return this; } - public int getIndexNumPartitions() { - return indexNumPartitions; + public Builder indexColumn(String indexColumn) { + this.indexColumn = indexColumn; + return this; } - public Integer getIndexNumSubVectors() { - return indexNumSubVectors; + public Builder indexNumPartitions(int indexNumPartitions) { + this.indexNumPartitions = indexNumPartitions; + return this; } - public int getIndexNumBits() { - return indexNumBits; + public Builder indexNumSubVectors(Integer indexNumSubVectors) { + this.indexNumSubVectors = indexNumSubVectors; + return this; } - public int getIndexMaxLevel() { - return indexMaxLevel; + public Builder indexNumBits(int indexNumBits) { + this.indexNumBits = indexNumBits; + return this; } - public int getIndexM() { - return indexM; + public Builder indexMaxLevel(int indexMaxLevel) { + this.indexMaxLevel = indexMaxLevel; + return this; } - public int getIndexEfConstruction() { - return indexEfConstruction; + public Builder indexM(int indexM) { + this.indexM = indexM; + return this; } - public String getVectorColumn() { - return vectorColumn; + public Builder indexEfConstruction(int indexEfConstruction) { + this.indexEfConstruction = indexEfConstruction; + return this; } - public MetricType getVectorMetric() { - return vectorMetric; + public Builder vectorColumn(String vectorColumn) { + this.vectorColumn = vectorColumn; + return this; } - public int getVectorNprobes() { - return vectorNprobes; + public Builder vectorMetric(MetricType vectorMetric) { + this.vectorMetric = vectorMetric; + return this; } - public int getVectorEf() { - return vectorEf; + public Builder vectorNprobes(int vectorNprobes) { + this.vectorNprobes = vectorNprobes; + return this; } - public Integer getVectorRefineFactor() { - return vectorRefineFactor; + public Builder vectorEf(int vectorEf) { + this.vectorEf = vectorEf; + return this; } - public String getDefaultDatabase() { - return defaultDatabase; + public Builder vectorRefineFactor(Integer vectorRefineFactor) { + this.vectorRefineFactor = vectorRefineFactor; + return this; } - public String getWarehouse() { - return warehouse; + public Builder defaultDatabase(String defaultDatabase) { + this.defaultDatabase = defaultDatabase; + return this; } - // ==================== Builder ==================== - - public static Builder builder() { - return new Builder(); + public Builder warehouse(String warehouse) { + this.warehouse = warehouse; + return this; } - /** - * Create LanceOptions from Flink Configuration - */ - public static LanceOptions fromConfiguration(Configuration config) { - Builder builder = builder(); - - // Common configuration - if (config.contains(PATH)) { - 
builder.path(config.get(PATH)); - } - - // Source configuration - builder.readBatchSize(config.get(READ_BATCH_SIZE)); - if (config.contains(READ_LIMIT)) { - builder.readLimit(config.get(READ_LIMIT)); - } - if (config.contains(READ_COLUMNS)) { - String columnsStr = config.get(READ_COLUMNS); - if (columnsStr != null && !columnsStr.isEmpty()) { - builder.readColumns(Arrays.asList(columnsStr.split(","))); - } - } - if (config.contains(READ_FILTER)) { - builder.readFilter(config.get(READ_FILTER)); - } - - // Sink configuration - builder.writeBatchSize(config.get(WRITE_BATCH_SIZE)); - builder.writeMode(WriteMode.fromValue(config.get(WRITE_MODE))); - builder.writeMaxRowsPerFile(config.get(WRITE_MAX_ROWS_PER_FILE)); - - // Index configuration - builder.indexType(IndexType.fromValue(config.get(INDEX_TYPE))); - if (config.contains(INDEX_COLUMN)) { - builder.indexColumn(config.get(INDEX_COLUMN)); - } - builder.indexNumPartitions(config.get(INDEX_NUM_PARTITIONS)); - if (config.contains(INDEX_NUM_SUB_VECTORS)) { - builder.indexNumSubVectors(config.get(INDEX_NUM_SUB_VECTORS)); - } - builder.indexNumBits(config.get(INDEX_NUM_BITS)); - builder.indexMaxLevel(config.get(INDEX_MAX_LEVEL)); - builder.indexM(config.get(INDEX_M)); - builder.indexEfConstruction(config.get(INDEX_EF_CONSTRUCTION)); - - // Vector search configuration - if (config.contains(VECTOR_COLUMN)) { - builder.vectorColumn(config.get(VECTOR_COLUMN)); - } - builder.vectorMetric(MetricType.fromValue(config.get(VECTOR_METRIC))); - builder.vectorNprobes(config.get(VECTOR_NPROBES)); - builder.vectorEf(config.get(VECTOR_EF)); - if (config.contains(VECTOR_REFINE_FACTOR)) { - builder.vectorRefineFactor(config.get(VECTOR_REFINE_FACTOR)); - } - - // Catalog configuration - builder.defaultDatabase(config.get(DEFAULT_DATABASE)); - if (config.contains(WAREHOUSE)) { - builder.warehouse(config.get(WAREHOUSE)); - } - - return builder.build(); - } - - /** - * Configuration builder - */ - public static class Builder { - private String path; - private int readBatchSize = 1024; - private Long readLimit; - private List readColumns = Collections.emptyList(); - private String readFilter; - private int writeBatchSize = 1024; - private WriteMode writeMode = WriteMode.APPEND; - private int writeMaxRowsPerFile = 1000000; - private IndexType indexType = IndexType.IVF_PQ; - private String indexColumn; - private int indexNumPartitions = 256; - private Integer indexNumSubVectors; - private int indexNumBits = 8; - private int indexMaxLevel = 7; - private int indexM = 16; - private int indexEfConstruction = 100; - private String vectorColumn; - private MetricType vectorMetric = MetricType.L2; - private int vectorNprobes = 20; - private int vectorEf = 100; - private Integer vectorRefineFactor; - private String defaultDatabase = "default"; - private String warehouse; - - public Builder path(String path) { - this.path = path; - return this; - } - - public Builder readBatchSize(int readBatchSize) { - this.readBatchSize = readBatchSize; - return this; - } - - public Builder readLimit(Long readLimit) { - this.readLimit = readLimit; - return this; - } - - public Builder readColumns(List readColumns) { - this.readColumns = readColumns != null ? 
readColumns : Collections.emptyList(); - return this; - } - - public Builder readFilter(String readFilter) { - this.readFilter = readFilter; - return this; - } - - public Builder writeBatchSize(int writeBatchSize) { - this.writeBatchSize = writeBatchSize; - return this; - } - - public Builder writeMode(WriteMode writeMode) { - this.writeMode = writeMode; - return this; - } - - public Builder writeMaxRowsPerFile(int writeMaxRowsPerFile) { - this.writeMaxRowsPerFile = writeMaxRowsPerFile; - return this; - } - - public Builder indexType(IndexType indexType) { - this.indexType = indexType; - return this; - } - - public Builder indexColumn(String indexColumn) { - this.indexColumn = indexColumn; - return this; - } - - public Builder indexNumPartitions(int indexNumPartitions) { - this.indexNumPartitions = indexNumPartitions; - return this; - } - - public Builder indexNumSubVectors(Integer indexNumSubVectors) { - this.indexNumSubVectors = indexNumSubVectors; - return this; - } - - public Builder indexNumBits(int indexNumBits) { - this.indexNumBits = indexNumBits; - return this; - } - - public Builder indexMaxLevel(int indexMaxLevel) { - this.indexMaxLevel = indexMaxLevel; - return this; - } - - public Builder indexM(int indexM) { - this.indexM = indexM; - return this; - } - - public Builder indexEfConstruction(int indexEfConstruction) { - this.indexEfConstruction = indexEfConstruction; - return this; - } - - public Builder vectorColumn(String vectorColumn) { - this.vectorColumn = vectorColumn; - return this; - } - - public Builder vectorMetric(MetricType vectorMetric) { - this.vectorMetric = vectorMetric; - return this; - } - - public Builder vectorNprobes(int vectorNprobes) { - this.vectorNprobes = vectorNprobes; - return this; - } - - public Builder vectorEf(int vectorEf) { - this.vectorEf = vectorEf; - return this; - } - - public Builder vectorRefineFactor(Integer vectorRefineFactor) { - this.vectorRefineFactor = vectorRefineFactor; - return this; - } - - public Builder defaultDatabase(String defaultDatabase) { - this.defaultDatabase = defaultDatabase; - return this; - } - - public Builder warehouse(String warehouse) { - this.warehouse = warehouse; - return this; - } + /** Build LanceOptions instance with validation */ + public LanceOptions build() { + validate(); + return new LanceOptions(this); + } - /** - * Build LanceOptions instance with validation - */ - public LanceOptions build() { - validate(); - return new LanceOptions(this); - } + /** Validate configuration */ + private void validate() { + // Validate read batch size + if (readBatchSize <= 0) { + throw new IllegalArgumentException( + "read.batch-size must be greater than 0, current value: " + readBatchSize); + } - /** - * Validate configuration - */ - private void validate() { - // Validate read batch size - if (readBatchSize <= 0) { - throw new IllegalArgumentException("read.batch-size must be greater than 0, current value: " + readBatchSize); - } - - // Validate Limit (if set) - if (readLimit != null && readLimit < 0) { - throw new IllegalArgumentException("read.limit must be greater than or equal to 0, current value: " + readLimit); - } - - // Validate write batch size - if (writeBatchSize <= 0) { - throw new IllegalArgumentException("write.batch-size must be greater than 0, current value: " + writeBatchSize); - } - - // Validate max rows per file - if (writeMaxRowsPerFile <= 0) { - throw new IllegalArgumentException("write.max-rows-per-file must be greater than 0, current value: " + writeMaxRowsPerFile); - } - - // Validate index 
partition count - if (indexNumPartitions <= 0) { - throw new IllegalArgumentException("index.num-partitions must be greater than 0, current value: " + indexNumPartitions); - } - - // Validate PQ sub-vector count - if (indexNumSubVectors != null && indexNumSubVectors <= 0) { - throw new IllegalArgumentException("index.num-sub-vectors must be greater than 0, current value: " + indexNumSubVectors); - } - - // Validate PQ quantization bits - if (indexNumBits <= 0 || indexNumBits > 16) { - throw new IllegalArgumentException("index.num-bits must be between 1 and 16, current value: " + indexNumBits); - } - - // Validate HNSW parameters - if (indexMaxLevel <= 0) { - throw new IllegalArgumentException("index.max-level must be greater than 0, current value: " + indexMaxLevel); - } - - if (indexM <= 0) { - throw new IllegalArgumentException("index.m must be greater than 0, current value: " + indexM); - } - - if (indexEfConstruction <= 0) { - throw new IllegalArgumentException("index.ef-construction must be greater than 0, current value: " + indexEfConstruction); - } - - // Validate vector search parameters - if (vectorNprobes <= 0) { - throw new IllegalArgumentException("vector.nprobes must be greater than 0, current value: " + vectorNprobes); - } - - if (vectorEf <= 0) { - throw new IllegalArgumentException("vector.ef must be greater than 0, current value: " + vectorEf); - } - - if (vectorRefineFactor != null && vectorRefineFactor <= 0) { - throw new IllegalArgumentException("vector.refine-factor must be greater than 0, current value: " + vectorRefineFactor); - } - } - } + // Validate Limit (if set) + if (readLimit != null && readLimit < 0) { + throw new IllegalArgumentException( + "read.limit must be greater than or equal to 0, current value: " + readLimit); + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - LanceOptions that = (LanceOptions) o; - return readBatchSize == that.readBatchSize && - Objects.equals(readLimit, that.readLimit) && - writeBatchSize == that.writeBatchSize && - writeMaxRowsPerFile == that.writeMaxRowsPerFile && - indexNumPartitions == that.indexNumPartitions && - indexNumBits == that.indexNumBits && - indexMaxLevel == that.indexMaxLevel && - indexM == that.indexM && - indexEfConstruction == that.indexEfConstruction && - vectorNprobes == that.vectorNprobes && - vectorEf == that.vectorEf && - Objects.equals(path, that.path) && - Objects.equals(readColumns, that.readColumns) && - Objects.equals(readFilter, that.readFilter) && - writeMode == that.writeMode && - indexType == that.indexType && - Objects.equals(indexColumn, that.indexColumn) && - Objects.equals(indexNumSubVectors, that.indexNumSubVectors) && - Objects.equals(vectorColumn, that.vectorColumn) && - vectorMetric == that.vectorMetric && - Objects.equals(vectorRefineFactor, that.vectorRefineFactor) && - Objects.equals(defaultDatabase, that.defaultDatabase) && - Objects.equals(warehouse, that.warehouse); - } - - @Override - public int hashCode() { - return Objects.hash(path, readBatchSize, readLimit, readColumns, readFilter, writeBatchSize, writeMode, - writeMaxRowsPerFile, indexType, indexColumn, indexNumPartitions, indexNumSubVectors, - indexNumBits, indexMaxLevel, indexM, indexEfConstruction, vectorColumn, vectorMetric, - vectorNprobes, vectorEf, vectorRefineFactor, defaultDatabase, warehouse); - } - - @Override - public String toString() { - return "LanceOptions{" + - "path='" + path + '\'' + - ", readBatchSize=" + 
readBatchSize + - ", readLimit=" + readLimit + - ", readColumns=" + readColumns + - ", readFilter='" + readFilter + '\'' + - ", writeBatchSize=" + writeBatchSize + - ", writeMode=" + writeMode + - ", writeMaxRowsPerFile=" + writeMaxRowsPerFile + - ", indexType=" + indexType + - ", indexColumn='" + indexColumn + '\'' + - ", indexNumPartitions=" + indexNumPartitions + - ", indexNumSubVectors=" + indexNumSubVectors + - ", indexNumBits=" + indexNumBits + - ", indexMaxLevel=" + indexMaxLevel + - ", indexM=" + indexM + - ", indexEfConstruction=" + indexEfConstruction + - ", vectorColumn='" + vectorColumn + '\'' + - ", vectorMetric=" + vectorMetric + - ", vectorNprobes=" + vectorNprobes + - ", vectorEf=" + vectorEf + - ", vectorRefineFactor=" + vectorRefineFactor + - ", defaultDatabase='" + defaultDatabase + '\'' + - ", warehouse='" + warehouse + '\'' + - '}'; - } + // Validate write batch size + if (writeBatchSize <= 0) { + throw new IllegalArgumentException( + "write.batch-size must be greater than 0, current value: " + writeBatchSize); + } + + // Validate max rows per file + if (writeMaxRowsPerFile <= 0) { + throw new IllegalArgumentException( + "write.max-rows-per-file must be greater than 0, current value: " + + writeMaxRowsPerFile); + } + + // Validate index partition count + if (indexNumPartitions <= 0) { + throw new IllegalArgumentException( + "index.num-partitions must be greater than 0, current value: " + indexNumPartitions); + } + + // Validate PQ sub-vector count + if (indexNumSubVectors != null && indexNumSubVectors <= 0) { + throw new IllegalArgumentException( + "index.num-sub-vectors must be greater than 0, current value: " + indexNumSubVectors); + } + + // Validate PQ quantization bits + if (indexNumBits <= 0 || indexNumBits > 16) { + throw new IllegalArgumentException( + "index.num-bits must be between 1 and 16, current value: " + indexNumBits); + } + + // Validate HNSW parameters + if (indexMaxLevel <= 0) { + throw new IllegalArgumentException( + "index.max-level must be greater than 0, current value: " + indexMaxLevel); + } + + if (indexM <= 0) { + throw new IllegalArgumentException( + "index.m must be greater than 0, current value: " + indexM); + } + + if (indexEfConstruction <= 0) { + throw new IllegalArgumentException( + "index.ef-construction must be greater than 0, current value: " + indexEfConstruction); + } + + // Validate vector search parameters + if (vectorNprobes <= 0) { + throw new IllegalArgumentException( + "vector.nprobes must be greater than 0, current value: " + vectorNprobes); + } + + if (vectorEf <= 0) { + throw new IllegalArgumentException( + "vector.ef must be greater than 0, current value: " + vectorEf); + } + + if (vectorRefineFactor != null && vectorRefineFactor <= 0) { + throw new IllegalArgumentException( + "vector.refine-factor must be greater than 0, current value: " + vectorRefineFactor); + } + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + LanceOptions that = (LanceOptions) o; + return readBatchSize == that.readBatchSize + && Objects.equals(readLimit, that.readLimit) + && writeBatchSize == that.writeBatchSize + && writeMaxRowsPerFile == that.writeMaxRowsPerFile + && indexNumPartitions == that.indexNumPartitions + && indexNumBits == that.indexNumBits + && indexMaxLevel == that.indexMaxLevel + && indexM == that.indexM + && indexEfConstruction == that.indexEfConstruction + && vectorNprobes == that.vectorNprobes + && vectorEf == that.vectorEf + && 
Objects.equals(path, that.path) + && Objects.equals(readColumns, that.readColumns) + && Objects.equals(readFilter, that.readFilter) + && writeMode == that.writeMode + && indexType == that.indexType + && Objects.equals(indexColumn, that.indexColumn) + && Objects.equals(indexNumSubVectors, that.indexNumSubVectors) + && Objects.equals(vectorColumn, that.vectorColumn) + && vectorMetric == that.vectorMetric + && Objects.equals(vectorRefineFactor, that.vectorRefineFactor) + && Objects.equals(defaultDatabase, that.defaultDatabase) + && Objects.equals(warehouse, that.warehouse); + } + + @Override + public int hashCode() { + return Objects.hash( + path, + readBatchSize, + readLimit, + readColumns, + readFilter, + writeBatchSize, + writeMode, + writeMaxRowsPerFile, + indexType, + indexColumn, + indexNumPartitions, + indexNumSubVectors, + indexNumBits, + indexMaxLevel, + indexM, + indexEfConstruction, + vectorColumn, + vectorMetric, + vectorNprobes, + vectorEf, + vectorRefineFactor, + defaultDatabase, + warehouse); + } + + @Override + public String toString() { + return "LanceOptions{" + + "path='" + + path + + '\'' + + ", readBatchSize=" + + readBatchSize + + ", readLimit=" + + readLimit + + ", readColumns=" + + readColumns + + ", readFilter='" + + readFilter + + '\'' + + ", writeBatchSize=" + + writeBatchSize + + ", writeMode=" + + writeMode + + ", writeMaxRowsPerFile=" + + writeMaxRowsPerFile + + ", indexType=" + + indexType + + ", indexColumn='" + + indexColumn + + '\'' + + ", indexNumPartitions=" + + indexNumPartitions + + ", indexNumSubVectors=" + + indexNumSubVectors + + ", indexNumBits=" + + indexNumBits + + ", indexMaxLevel=" + + indexMaxLevel + + ", indexM=" + + indexM + + ", indexEfConstruction=" + + indexEfConstruction + + ", vectorColumn='" + + vectorColumn + + '\'' + + ", vectorMetric=" + + vectorMetric + + ", vectorNprobes=" + + vectorNprobes + + ", vectorEf=" + + vectorEf + + ", vectorRefineFactor=" + + vectorRefineFactor + + ", defaultDatabase='" + + defaultDatabase + + '\'' + + ", warehouse='" + + warehouse + + '\'' + + '}'; + } } diff --git a/src/main/java/org/apache/flink/connector/lance/converter/LanceTypeConverter.java b/src/main/java/org/apache/flink/connector/lance/converter/LanceTypeConverter.java index d3b8e89..8436eb4 100644 --- a/src/main/java/org/apache/flink/connector/lance/converter/LanceTypeConverter.java +++ b/src/main/java/org/apache/flink/connector/lance/converter/LanceTypeConverter.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.flink.connector.lance.converter; import org.apache.flink.table.api.DataTypes; @@ -52,389 +47,390 @@ /** * Type converter between Lance/Arrow and Flink types. - * + * *

- * <p>Supported type mappings:
- * <ul>
- *   <li>Int8 <-> TINYINT
- *   <li>Int16 <-> SMALLINT
- *   <li>Int32 <-> INT
- *   <li>Int64 <-> BIGINT
- *   <li>Float32 <-> FLOAT
- *   <li>Float64 <-> DOUBLE
- *   <li>String/LargeString <-> STRING
- *   <li>Boolean <-> BOOLEAN
- *   <li>Binary/LargeBinary <-> BYTES
- *   <li>Date32 <-> DATE
- *   <li>Timestamp <-> TIMESTAMP
- *   <li>FixedSizeList <-> ARRAY
- *   <li>FixedSizeList <-> ARRAY
- * </ul>
+ * <p>Supported type mappings:
+ *
+ * <ul>
+ *   <li>Int8 <-> TINYINT
+ *   <li>Int16 <-> SMALLINT
+ *   <li>Int32 <-> INT
+ *   <li>Int64 <-> BIGINT
+ *   <li>Float32 <-> FLOAT
+ *   <li>Float64 <-> DOUBLE
+ *   <li>String/LargeString <-> STRING
+ *   <li>Boolean <-> BOOLEAN
+ *   <li>Binary/LargeBinary <-> BYTES
+ *   <li>Date32 <-> DATE
+ *   <li>Timestamp <-> TIMESTAMP
+ *   <li>FixedSizeList <-> ARRAY
+ *   <li>FixedSizeList <-> ARRAY
+ * </ul>
*/ public class LanceTypeConverter implements Serializable { - private static final long serialVersionUID = 1L; - private static final Logger LOG = LoggerFactory.getLogger(LanceTypeConverter.class); + private static final long serialVersionUID = 1L; + private static final Logger LOG = LoggerFactory.getLogger(LanceTypeConverter.class); - /** - * Convert Arrow Schema to Flink RowType - * - * @param schema Arrow Schema - * @return Flink RowType - */ - public static RowType toFlinkRowType(Schema schema) { - List fields = new ArrayList<>(); - for (Field field : schema.getFields()) { - LogicalType logicalType = arrowTypeToFlinkType(field); - fields.add(new RowType.RowField(field.getName(), logicalType)); - } - return new RowType(fields); + /** + * Convert Arrow Schema to Flink RowType + * + * @param schema Arrow Schema + * @return Flink RowType + */ + public static RowType toFlinkRowType(Schema schema) { + List fields = new ArrayList<>(); + for (Field field : schema.getFields()) { + LogicalType logicalType = arrowTypeToFlinkType(field); + fields.add(new RowType.RowField(field.getName(), logicalType)); } + return new RowType(fields); + } - /** - * Convert Flink RowType to Arrow Schema - * - * @param rowType Flink RowType - * @return Arrow Schema - */ - public static Schema toArrowSchema(RowType rowType) { - List fields = new ArrayList<>(); - for (RowType.RowField rowField : rowType.getFields()) { - Field arrowField = flinkTypeToArrowField(rowField.getName(), rowField.getType()); - fields.add(arrowField); - } - return new Schema(fields); + /** + * Convert Flink RowType to Arrow Schema + * + * @param rowType Flink RowType + * @return Arrow Schema + */ + public static Schema toArrowSchema(RowType rowType) { + List fields = new ArrayList<>(); + for (RowType.RowField rowField : rowType.getFields()) { + Field arrowField = flinkTypeToArrowField(rowField.getName(), rowField.getType()); + fields.add(arrowField); } + return new Schema(fields); + } - /** - * Convert Arrow Field to Flink LogicalType - * - * @param field Arrow Field - * @return Flink LogicalType - */ - public static LogicalType arrowTypeToFlinkType(Field field) { - ArrowType arrowType = field.getType(); - boolean nullable = field.isNullable(); - - if (arrowType instanceof ArrowType.Int) { - ArrowType.Int intType = (ArrowType.Int) arrowType; - int bitWidth = intType.getBitWidth(); - switch (bitWidth) { - case 8: - return new TinyIntType(nullable); - case 16: - return new SmallIntType(nullable); - case 32: - return new IntType(nullable); - case 64: - return new BigIntType(nullable); - default: - throw new UnsupportedTypeException("Unsupported Arrow Int bit width: " + bitWidth); - } - } else if (arrowType instanceof ArrowType.FloatingPoint) { - ArrowType.FloatingPoint fpType = (ArrowType.FloatingPoint) arrowType; - FloatingPointPrecision precision = fpType.getPrecision(); - switch (precision) { - case SINGLE: - return new FloatType(nullable); - case DOUBLE: - return new DoubleType(nullable); - default: - throw new UnsupportedTypeException("Unsupported Arrow floating point precision: " + precision); - } - } else if (arrowType instanceof ArrowType.Utf8 || arrowType instanceof ArrowType.LargeUtf8) { - return new VarCharType(nullable, VarCharType.MAX_LENGTH); - } else if (arrowType instanceof ArrowType.Bool) { - return new BooleanType(nullable); - } else if (arrowType instanceof ArrowType.Binary) { - return new VarBinaryType(nullable, VarBinaryType.MAX_LENGTH); - } else if (arrowType instanceof ArrowType.LargeBinary) { - return new 
VarBinaryType(nullable, VarBinaryType.MAX_LENGTH); - } else if (arrowType instanceof ArrowType.FixedSizeBinary) { - ArrowType.FixedSizeBinary fixedBinary = (ArrowType.FixedSizeBinary) arrowType; - return new BinaryType(nullable, fixedBinary.getByteWidth()); - } else if (arrowType instanceof ArrowType.Date) { - return new DateType(nullable); - } else if (arrowType instanceof ArrowType.Timestamp) { - ArrowType.Timestamp tsType = (ArrowType.Timestamp) arrowType; - // Determine precision based on time unit - int precision = getTimestampPrecision(tsType.getUnit()); - return new TimestampType(nullable, precision); - } else if (arrowType instanceof ArrowType.FixedSizeList) { - // Vector type: FixedSizeList - ArrowType.FixedSizeList listType = (ArrowType.FixedSizeList) arrowType; - List children = field.getChildren(); - if (children != null && !children.isEmpty()) { - LogicalType elementType = arrowTypeToFlinkType(children.get(0)); - return new ArrayType(nullable, elementType); - } - throw new UnsupportedTypeException("FixedSizeList must contain child type"); - } else if (arrowType instanceof ArrowType.List || arrowType instanceof ArrowType.LargeList) { - // Regular list type - List children = field.getChildren(); - if (children != null && !children.isEmpty()) { - LogicalType elementType = arrowTypeToFlinkType(children.get(0)); - return new ArrayType(nullable, elementType); - } - throw new UnsupportedTypeException("List must contain child type"); - } else if (arrowType instanceof ArrowType.Struct) { - // Struct type - List structFields = new ArrayList<>(); - for (Field child : field.getChildren()) { - LogicalType childType = arrowTypeToFlinkType(child); - structFields.add(new RowType.RowField(child.getName(), childType)); - } - return new RowType(nullable, structFields); - } else if (arrowType instanceof ArrowType.Null) { - // Null type, map to nullable string - LOG.warn("Arrow Null type mapped to nullable STRING type"); - return new VarCharType(true, VarCharType.MAX_LENGTH); - } + /** + * Convert Arrow Field to Flink LogicalType + * + * @param field Arrow Field + * @return Flink LogicalType + */ + public static LogicalType arrowTypeToFlinkType(Field field) { + ArrowType arrowType = field.getType(); + boolean nullable = field.isNullable(); - throw new UnsupportedTypeException("Unsupported Arrow type: " + arrowType.getClass().getSimpleName()); + if (arrowType instanceof ArrowType.Int) { + ArrowType.Int intType = (ArrowType.Int) arrowType; + int bitWidth = intType.getBitWidth(); + switch (bitWidth) { + case 8: + return new TinyIntType(nullable); + case 16: + return new SmallIntType(nullable); + case 32: + return new IntType(nullable); + case 64: + return new BigIntType(nullable); + default: + throw new UnsupportedTypeException("Unsupported Arrow Int bit width: " + bitWidth); + } + } else if (arrowType instanceof ArrowType.FloatingPoint) { + ArrowType.FloatingPoint fpType = (ArrowType.FloatingPoint) arrowType; + FloatingPointPrecision precision = fpType.getPrecision(); + switch (precision) { + case SINGLE: + return new FloatType(nullable); + case DOUBLE: + return new DoubleType(nullable); + default: + throw new UnsupportedTypeException( + "Unsupported Arrow floating point precision: " + precision); + } + } else if (arrowType instanceof ArrowType.Utf8 || arrowType instanceof ArrowType.LargeUtf8) { + return new VarCharType(nullable, VarCharType.MAX_LENGTH); + } else if (arrowType instanceof ArrowType.Bool) { + return new BooleanType(nullable); + } else if (arrowType instanceof ArrowType.Binary) { + 
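+      // Arrow Binary and LargeBinary (next branch) both map to Flink's
+      // variable-length BYTES (VarBinaryType); only FixedSizeBinary below
+      // preserves a fixed byte width as BINARY(n).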
return new VarBinaryType(nullable, VarBinaryType.MAX_LENGTH); + } else if (arrowType instanceof ArrowType.LargeBinary) { + return new VarBinaryType(nullable, VarBinaryType.MAX_LENGTH); + } else if (arrowType instanceof ArrowType.FixedSizeBinary) { + ArrowType.FixedSizeBinary fixedBinary = (ArrowType.FixedSizeBinary) arrowType; + return new BinaryType(nullable, fixedBinary.getByteWidth()); + } else if (arrowType instanceof ArrowType.Date) { + return new DateType(nullable); + } else if (arrowType instanceof ArrowType.Timestamp) { + ArrowType.Timestamp tsType = (ArrowType.Timestamp) arrowType; + // Determine precision based on time unit + int precision = getTimestampPrecision(tsType.getUnit()); + return new TimestampType(nullable, precision); + } else if (arrowType instanceof ArrowType.FixedSizeList) { + // Vector type: FixedSizeList + ArrowType.FixedSizeList listType = (ArrowType.FixedSizeList) arrowType; + List children = field.getChildren(); + if (children != null && !children.isEmpty()) { + LogicalType elementType = arrowTypeToFlinkType(children.get(0)); + return new ArrayType(nullable, elementType); + } + throw new UnsupportedTypeException("FixedSizeList must contain child type"); + } else if (arrowType instanceof ArrowType.List || arrowType instanceof ArrowType.LargeList) { + // Regular list type + List children = field.getChildren(); + if (children != null && !children.isEmpty()) { + LogicalType elementType = arrowTypeToFlinkType(children.get(0)); + return new ArrayType(nullable, elementType); + } + throw new UnsupportedTypeException("List must contain child type"); + } else if (arrowType instanceof ArrowType.Struct) { + // Struct type + List structFields = new ArrayList<>(); + for (Field child : field.getChildren()) { + LogicalType childType = arrowTypeToFlinkType(child); + structFields.add(new RowType.RowField(child.getName(), childType)); + } + return new RowType(nullable, structFields); + } else if (arrowType instanceof ArrowType.Null) { + // Null type, map to nullable string + LOG.warn("Arrow Null type mapped to nullable STRING type"); + return new VarCharType(true, VarCharType.MAX_LENGTH); } - /** - * Convert Flink LogicalType to Arrow Field - * - * @param name Field name - * @param logicalType Flink LogicalType - * @return Arrow Field - */ - public static Field flinkTypeToArrowField(String name, LogicalType logicalType) { - boolean nullable = logicalType.isNullable(); - ArrowType arrowType; - List children = null; + throw new UnsupportedTypeException( + "Unsupported Arrow type: " + arrowType.getClass().getSimpleName()); + } - if (logicalType instanceof TinyIntType) { - arrowType = new ArrowType.Int(8, true); - } else if (logicalType instanceof SmallIntType) { - arrowType = new ArrowType.Int(16, true); - } else if (logicalType instanceof IntType) { - arrowType = new ArrowType.Int(32, true); - } else if (logicalType instanceof BigIntType) { - arrowType = new ArrowType.Int(64, true); - } else if (logicalType instanceof FloatType) { - arrowType = new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE); - } else if (logicalType instanceof DoubleType) { - arrowType = new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE); - } else if (logicalType instanceof VarCharType) { - arrowType = ArrowType.Utf8.INSTANCE; - } else if (logicalType instanceof BooleanType) { - arrowType = ArrowType.Bool.INSTANCE; - } else if (logicalType instanceof VarBinaryType) { - arrowType = ArrowType.Binary.INSTANCE; - } else if (logicalType instanceof BinaryType) { - BinaryType binaryType = (BinaryType) 
logicalType; - arrowType = new ArrowType.FixedSizeBinary(binaryType.getLength()); - } else if (logicalType instanceof DateType) { - arrowType = new ArrowType.Date(DateUnit.DAY); - } else if (logicalType instanceof TimestampType) { - TimestampType tsType = (TimestampType) logicalType; - TimeUnit timeUnit = getArrowTimeUnit(tsType.getPrecision()); - arrowType = new ArrowType.Timestamp(timeUnit, null); - } else if (logicalType instanceof ArrayType) { - ArrayType arrayType = (ArrayType) logicalType; - LogicalType elementType = arrayType.getElementType(); - Field childField = flinkTypeToArrowField("item", elementType); - children = new ArrayList<>(); - children.add(childField); - // For vector types, use List type - arrowType = ArrowType.List.INSTANCE; - } else if (logicalType instanceof RowType) { - RowType rowType = (RowType) logicalType; - children = new ArrayList<>(); - for (RowType.RowField rowField : rowType.getFields()) { - Field childField = flinkTypeToArrowField(rowField.getName(), rowField.getType()); - children.add(childField); - } - arrowType = ArrowType.Struct.INSTANCE; - } else { - throw new UnsupportedTypeException("Unsupported Flink type: " + logicalType.getClass().getSimpleName()); - } + /** + * Convert Flink LogicalType to Arrow Field + * + * @param name Field name + * @param logicalType Flink LogicalType + * @return Arrow Field + */ + public static Field flinkTypeToArrowField(String name, LogicalType logicalType) { + boolean nullable = logicalType.isNullable(); + ArrowType arrowType; + List children = null; - FieldType fieldType = new FieldType(nullable, arrowType, null); - return new Field(name, fieldType, children); + if (logicalType instanceof TinyIntType) { + arrowType = new ArrowType.Int(8, true); + } else if (logicalType instanceof SmallIntType) { + arrowType = new ArrowType.Int(16, true); + } else if (logicalType instanceof IntType) { + arrowType = new ArrowType.Int(32, true); + } else if (logicalType instanceof BigIntType) { + arrowType = new ArrowType.Int(64, true); + } else if (logicalType instanceof FloatType) { + arrowType = new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE); + } else if (logicalType instanceof DoubleType) { + arrowType = new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE); + } else if (logicalType instanceof VarCharType) { + arrowType = ArrowType.Utf8.INSTANCE; + } else if (logicalType instanceof BooleanType) { + arrowType = ArrowType.Bool.INSTANCE; + } else if (logicalType instanceof VarBinaryType) { + arrowType = ArrowType.Binary.INSTANCE; + } else if (logicalType instanceof BinaryType) { + BinaryType binaryType = (BinaryType) logicalType; + arrowType = new ArrowType.FixedSizeBinary(binaryType.getLength()); + } else if (logicalType instanceof DateType) { + arrowType = new ArrowType.Date(DateUnit.DAY); + } else if (logicalType instanceof TimestampType) { + TimestampType tsType = (TimestampType) logicalType; + TimeUnit timeUnit = getArrowTimeUnit(tsType.getPrecision()); + arrowType = new ArrowType.Timestamp(timeUnit, null); + } else if (logicalType instanceof ArrayType) { + ArrayType arrayType = (ArrayType) logicalType; + LogicalType elementType = arrayType.getElementType(); + Field childField = flinkTypeToArrowField("item", elementType); + children = new ArrayList<>(); + children.add(childField); + // For vector types, use List type + arrowType = ArrowType.List.INSTANCE; + } else if (logicalType instanceof RowType) { + RowType rowType = (RowType) logicalType; + children = new ArrayList<>(); + for (RowType.RowField rowField : 
rowType.getFields()) { + Field childField = flinkTypeToArrowField(rowField.getName(), rowField.getType()); + children.add(childField); + } + arrowType = ArrowType.Struct.INSTANCE; + } else { + throw new UnsupportedTypeException( + "Unsupported Flink type: " + logicalType.getClass().getSimpleName()); } - /** - * Create vector field (FixedSizeList) - * - * @param name Field name - * @param dimension Vector dimension - * @param nullable Whether nullable - * @return Arrow Field - */ - public static Field createVectorField(String name, int dimension, boolean nullable) { - ArrowType elementType = new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE); - Field elementField = new Field("item", new FieldType(false, elementType, null), null); - - ArrowType listType = new ArrowType.FixedSizeList(dimension); - List children = new ArrayList<>(); - children.add(elementField); - - return new Field(name, new FieldType(nullable, listType, null), children); - } + FieldType fieldType = new FieldType(nullable, arrowType, null); + return new Field(name, fieldType, children); + } + + /** + * Create vector field (FixedSizeList) + * + * @param name Field name + * @param dimension Vector dimension + * @param nullable Whether nullable + * @return Arrow Field + */ + public static Field createVectorField(String name, int dimension, boolean nullable) { + ArrowType elementType = new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE); + Field elementField = new Field("item", new FieldType(false, elementType, null), null); + + ArrowType listType = new ArrowType.FixedSizeList(dimension); + List children = new ArrayList<>(); + children.add(elementField); + + return new Field(name, new FieldType(nullable, listType, null), children); + } - /** - * Create Float64 vector field (FixedSizeList) - * - * @param name Field name - * @param dimension Vector dimension - * @param nullable Whether nullable - * @return Arrow Field - */ - public static Field createFloat64VectorField(String name, int dimension, boolean nullable) { - ArrowType elementType = new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE); - Field elementField = new Field("item", new FieldType(false, elementType, null), null); - - ArrowType listType = new ArrowType.FixedSizeList(dimension); - List children = new ArrayList<>(); - children.add(elementField); - - return new Field(name, new FieldType(nullable, listType, null), children); + /** + * Create Float64 vector field (FixedSizeList) + * + * @param name Field name + * @param dimension Vector dimension + * @param nullable Whether nullable + * @return Arrow Field + */ + public static Field createFloat64VectorField(String name, int dimension, boolean nullable) { + ArrowType elementType = new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE); + Field elementField = new Field("item", new FieldType(false, elementType, null), null); + + ArrowType listType = new ArrowType.FixedSizeList(dimension); + List children = new ArrayList<>(); + children.add(elementField); + + return new Field(name, new FieldType(nullable, listType, null), children); + } + + /** + * Check if field is vector type (FixedSizeList) + * + * @param field Arrow Field + * @return Whether vector type + */ + public static boolean isVectorField(Field field) { + ArrowType arrowType = field.getType(); + if (!(arrowType instanceof ArrowType.FixedSizeList)) { + return false; } - /** - * Check if field is vector type (FixedSizeList) - * - * @param field Arrow Field - * @return Whether vector type - */ - public static boolean isVectorField(Field 
field) { - ArrowType arrowType = field.getType(); - if (!(arrowType instanceof ArrowType.FixedSizeList)) { - return false; - } - - List children = field.getChildren(); - if (children == null || children.isEmpty()) { - return false; - } - - ArrowType childType = children.get(0).getType(); - if (childType instanceof ArrowType.FloatingPoint) { - FloatingPointPrecision precision = ((ArrowType.FloatingPoint) childType).getPrecision(); - return precision == FloatingPointPrecision.SINGLE || precision == FloatingPointPrecision.DOUBLE; - } - - return false; + List children = field.getChildren(); + if (children == null || children.isEmpty()) { + return false; } - /** - * Get vector field dimension - * - * @param field Arrow Field - * @return Vector dimension, returns -1 if not vector field - */ - public static int getVectorDimension(Field field) { - ArrowType arrowType = field.getType(); - if (arrowType instanceof ArrowType.FixedSizeList) { - return ((ArrowType.FixedSizeList) arrowType).getListSize(); - } - return -1; + ArrowType childType = children.get(0).getType(); + if (childType instanceof ArrowType.FloatingPoint) { + FloatingPointPrecision precision = ((ArrowType.FloatingPoint) childType).getPrecision(); + return precision == FloatingPointPrecision.SINGLE + || precision == FloatingPointPrecision.DOUBLE; } - /** - * Convert Flink DataType to LogicalType - * - * @param dataType Flink DataType - * @return LogicalType - */ - public static LogicalType toLogicalType(DataType dataType) { - return dataType.getLogicalType(); + return false; + } + + /** + * Get vector field dimension + * + * @param field Arrow Field + * @return Vector dimension, returns -1 if not vector field + */ + public static int getVectorDimension(Field field) { + ArrowType arrowType = field.getType(); + if (arrowType instanceof ArrowType.FixedSizeList) { + return ((ArrowType.FixedSizeList) arrowType).getListSize(); } + return -1; + } - /** - * Convert LogicalType to Flink DataType - * - * @param logicalType Flink LogicalType - * @return Flink DataType - */ - public static DataType toDataType(LogicalType logicalType) { - if (logicalType instanceof TinyIntType) { - return DataTypes.TINYINT(); - } else if (logicalType instanceof SmallIntType) { - return DataTypes.SMALLINT(); - } else if (logicalType instanceof IntType) { - return DataTypes.INT(); - } else if (logicalType instanceof BigIntType) { - return DataTypes.BIGINT(); - } else if (logicalType instanceof FloatType) { - return DataTypes.FLOAT(); - } else if (logicalType instanceof DoubleType) { - return DataTypes.DOUBLE(); - } else if (logicalType instanceof VarCharType) { - return DataTypes.STRING(); - } else if (logicalType instanceof BooleanType) { - return DataTypes.BOOLEAN(); - } else if (logicalType instanceof VarBinaryType) { - return DataTypes.BYTES(); - } else if (logicalType instanceof BinaryType) { - BinaryType binaryType = (BinaryType) logicalType; - return DataTypes.BINARY(binaryType.getLength()); - } else if (logicalType instanceof DateType) { - return DataTypes.DATE(); - } else if (logicalType instanceof TimestampType) { - TimestampType tsType = (TimestampType) logicalType; - return DataTypes.TIMESTAMP(tsType.getPrecision()); - } else if (logicalType instanceof ArrayType) { - ArrayType arrayType = (ArrayType) logicalType; - DataType elementDataType = toDataType(arrayType.getElementType()); - return DataTypes.ARRAY(elementDataType); - } else if (logicalType instanceof RowType) { - RowType rowType = (RowType) logicalType; - DataTypes.Field[] fields = 
rowType.getFields().stream() - .map(f -> DataTypes.FIELD(f.getName(), toDataType(f.getType()))) - .toArray(DataTypes.Field[]::new); - return DataTypes.ROW(fields); - } - - throw new UnsupportedTypeException("Unsupported LogicalType: " + logicalType.getClass().getSimpleName()); + /** + * Convert Flink DataType to LogicalType + * + * @param dataType Flink DataType + * @return LogicalType + */ + public static LogicalType toLogicalType(DataType dataType) { + return dataType.getLogicalType(); + } + + /** + * Convert LogicalType to Flink DataType + * + * @param logicalType Flink LogicalType + * @return Flink DataType + */ + public static DataType toDataType(LogicalType logicalType) { + if (logicalType instanceof TinyIntType) { + return DataTypes.TINYINT(); + } else if (logicalType instanceof SmallIntType) { + return DataTypes.SMALLINT(); + } else if (logicalType instanceof IntType) { + return DataTypes.INT(); + } else if (logicalType instanceof BigIntType) { + return DataTypes.BIGINT(); + } else if (logicalType instanceof FloatType) { + return DataTypes.FLOAT(); + } else if (logicalType instanceof DoubleType) { + return DataTypes.DOUBLE(); + } else if (logicalType instanceof VarCharType) { + return DataTypes.STRING(); + } else if (logicalType instanceof BooleanType) { + return DataTypes.BOOLEAN(); + } else if (logicalType instanceof VarBinaryType) { + return DataTypes.BYTES(); + } else if (logicalType instanceof BinaryType) { + BinaryType binaryType = (BinaryType) logicalType; + return DataTypes.BINARY(binaryType.getLength()); + } else if (logicalType instanceof DateType) { + return DataTypes.DATE(); + } else if (logicalType instanceof TimestampType) { + TimestampType tsType = (TimestampType) logicalType; + return DataTypes.TIMESTAMP(tsType.getPrecision()); + } else if (logicalType instanceof ArrayType) { + ArrayType arrayType = (ArrayType) logicalType; + DataType elementDataType = toDataType(arrayType.getElementType()); + return DataTypes.ARRAY(elementDataType); + } else if (logicalType instanceof RowType) { + RowType rowType = (RowType) logicalType; + DataTypes.Field[] fields = + rowType.getFields().stream() + .map(f -> DataTypes.FIELD(f.getName(), toDataType(f.getType()))) + .toArray(DataTypes.Field[]::new); + return DataTypes.ROW(fields); } - /** - * Get Flink Timestamp precision based on Arrow TimeUnit - */ - private static int getTimestampPrecision(TimeUnit timeUnit) { - switch (timeUnit) { - case SECOND: - return 0; - case MILLISECOND: - return 3; - case MICROSECOND: - return 6; - case NANOSECOND: - return 9; - default: - return 6; // Default microsecond precision - } + throw new UnsupportedTypeException( + "Unsupported LogicalType: " + logicalType.getClass().getSimpleName()); + } + + /** Get Flink Timestamp precision based on Arrow TimeUnit */ + private static int getTimestampPrecision(TimeUnit timeUnit) { + switch (timeUnit) { + case SECOND: + return 0; + case MILLISECOND: + return 3; + case MICROSECOND: + return 6; + case NANOSECOND: + return 9; + default: + return 6; // Default microsecond precision } + } - /** - * Get Arrow TimeUnit based on Flink Timestamp precision - */ - private static TimeUnit getArrowTimeUnit(int precision) { - if (precision <= 0) { - return TimeUnit.SECOND; - } else if (precision <= 3) { - return TimeUnit.MILLISECOND; - } else if (precision <= 6) { - return TimeUnit.MICROSECOND; - } else { - return TimeUnit.NANOSECOND; - } + /** Get Arrow TimeUnit based on Flink Timestamp precision */ + private static TimeUnit getArrowTimeUnit(int precision) { + if (precision 
<= 0) { + return TimeUnit.SECOND; + } else if (precision <= 3) { + return TimeUnit.MILLISECOND; + } else if (precision <= 6) { + return TimeUnit.MICROSECOND; + } else { + return TimeUnit.NANOSECOND; } + } - /** - * Unsupported type exception - */ - public static class UnsupportedTypeException extends RuntimeException { - public UnsupportedTypeException(String message) { - super(message); - } + /** Unsupported type exception */ + public static class UnsupportedTypeException extends RuntimeException { + public UnsupportedTypeException(String message) { + super(message); + } - public UnsupportedTypeException(String message, Throwable cause) { - super(message, cause); - } + public UnsupportedTypeException(String message, Throwable cause) { + super(message, cause); } + } } diff --git a/src/main/java/org/apache/flink/connector/lance/converter/RowDataConverter.java b/src/main/java/org/apache/flink/connector/lance/converter/RowDataConverter.java index 2c727ea..f14291e 100644 --- a/src/main/java/org/apache/flink/connector/lance/converter/RowDataConverter.java +++ b/src/main/java/org/apache/flink/connector/lance/converter/RowDataConverter.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.flink.connector.lance.converter; import org.apache.flink.table.data.ArrayData; @@ -66,660 +61,627 @@ import org.slf4j.LoggerFactory; import java.io.Serializable; -import java.nio.charset.StandardCharsets; -import java.time.Instant; -import java.time.LocalDate; import java.util.ArrayList; import java.util.List; /** * Converter between RowData and Arrow data. - * + * *

Responsible for bidirectional conversion between Arrow VectorSchemaRoot and Flink RowData. */ public class RowDataConverter implements Serializable { - private static final long serialVersionUID = 1L; - private static final Logger LOG = LoggerFactory.getLogger(RowDataConverter.class); - - private final RowType rowType; - private final String[] fieldNames; - private final LogicalType[] fieldTypes; - - public RowDataConverter(RowType rowType) { - this.rowType = rowType; - this.fieldNames = rowType.getFieldNames().toArray(new String[0]); - this.fieldTypes = rowType.getFields().stream() - .map(RowType.RowField::getType) - .toArray(LogicalType[]::new); - } + private static final long serialVersionUID = 1L; + private static final Logger LOG = LoggerFactory.getLogger(RowDataConverter.class); + + private final RowType rowType; + private final String[] fieldNames; + private final LogicalType[] fieldTypes; + + public RowDataConverter(RowType rowType) { + this.rowType = rowType; + this.fieldNames = rowType.getFieldNames().toArray(new String[0]); + this.fieldTypes = + rowType.getFields().stream().map(RowType.RowField::getType).toArray(LogicalType[]::new); + } + + /** + * Convert Arrow VectorSchemaRoot to RowData list + * + * @param root Arrow VectorSchemaRoot + * @return RowData list + */ + public List toRowDataList(VectorSchemaRoot root) { + List rows = new ArrayList<>(); + int rowCount = root.getRowCount(); + + for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { + GenericRowData rowData = new GenericRowData(fieldTypes.length); + + for (int fieldIndex = 0; fieldIndex < fieldTypes.length; fieldIndex++) { + String fieldName = fieldNames[fieldIndex]; + FieldVector vector = root.getVector(fieldName); - /** - * Convert Arrow VectorSchemaRoot to RowData list - * - * @param root Arrow VectorSchemaRoot - * @return RowData list - */ - public List toRowDataList(VectorSchemaRoot root) { - List rows = new ArrayList<>(); - int rowCount = root.getRowCount(); - - for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { - GenericRowData rowData = new GenericRowData(fieldTypes.length); - - for (int fieldIndex = 0; fieldIndex < fieldTypes.length; fieldIndex++) { - String fieldName = fieldNames[fieldIndex]; - FieldVector vector = root.getVector(fieldName); - - if (vector == null) { - rowData.setField(fieldIndex, null); - continue; - } - - Object value = readValue(vector, rowIndex, fieldTypes[fieldIndex]); - rowData.setField(fieldIndex, value); - } - - rows.add(rowData); + if (vector == null) { + rowData.setField(fieldIndex, null); + continue; } - - return rows; - } - /** - * Write RowData list to Arrow VectorSchemaRoot - * - * @param rows RowData list - * @param root Arrow VectorSchemaRoot - */ - public void toVectorSchemaRoot(List rows, VectorSchemaRoot root) { - root.allocateNew(); - - for (int rowIndex = 0; rowIndex < rows.size(); rowIndex++) { - RowData rowData = rows.get(rowIndex); - - for (int fieldIndex = 0; fieldIndex < fieldTypes.length; fieldIndex++) { - String fieldName = fieldNames[fieldIndex]; - FieldVector vector = root.getVector(fieldName); - - if (vector == null) { - continue; - } - - Object value = getFieldValue(rowData, fieldIndex, fieldTypes[fieldIndex]); - writeValue(vector, rowIndex, value, fieldTypes[fieldIndex]); - } - } - - root.setRowCount(rows.size()); - } + Object value = readValue(vector, rowIndex, fieldTypes[fieldIndex]); + rowData.setField(fieldIndex, value); + } - /** - * Create VectorSchemaRoot - * - * @param allocator Memory allocator - * @return VectorSchemaRoot - */ - public 
VectorSchemaRoot createVectorSchemaRoot(BufferAllocator allocator) { - Schema arrowSchema = LanceTypeConverter.toArrowSchema(rowType); - return VectorSchemaRoot.create(arrowSchema, allocator); + rows.add(rowData); } - /** - * Read value from Arrow Vector - */ - private Object readValue(FieldVector vector, int index, LogicalType logicalType) { - if (vector.isNull(index)) { - return null; - } + return rows; + } - if (logicalType instanceof TinyIntType) { - return ((TinyIntVector) vector).get(index); - } else if (logicalType instanceof SmallIntType) { - return ((SmallIntVector) vector).get(index); - } else if (logicalType instanceof IntType) { - return ((IntVector) vector).get(index); - } else if (logicalType instanceof BigIntType) { - return ((BigIntVector) vector).get(index); - } else if (logicalType instanceof FloatType) { - return ((Float4Vector) vector).get(index); - } else if (logicalType instanceof DoubleType) { - return ((Float8Vector) vector).get(index); - } else if (logicalType instanceof VarCharType) { - byte[] bytes = ((VarCharVector) vector).get(index); - return StringData.fromBytes(bytes); - } else if (logicalType instanceof BooleanType) { - return ((BitVector) vector).get(index) == 1; - } else if (logicalType instanceof VarBinaryType) { - return ((VarBinaryVector) vector).get(index); - } else if (logicalType instanceof BinaryType) { - return ((FixedSizeBinaryVector) vector).get(index); - } else if (logicalType instanceof DateType) { - int daysSinceEpoch = ((DateDayVector) vector).get(index); - return daysSinceEpoch; - } else if (logicalType instanceof TimestampType) { - return readTimestamp(vector, index, (TimestampType) logicalType); - } else if (logicalType instanceof ArrayType) { - return readArray(vector, index, (ArrayType) logicalType); - } else if (logicalType instanceof RowType) { - return readStruct(vector, index, (RowType) logicalType); - } + /** + * Write RowData list to Arrow VectorSchemaRoot + * + * @param rows RowData list + * @param root Arrow VectorSchemaRoot + */ + public void toVectorSchemaRoot(List rows, VectorSchemaRoot root) { + root.allocateNew(); - throw new LanceTypeConverter.UnsupportedTypeException( - "Unsupported read type: " + logicalType.getClass().getSimpleName()); - } + for (int rowIndex = 0; rowIndex < rows.size(); rowIndex++) { + RowData rowData = rows.get(rowIndex); - /** - * Read timestamp value - */ - private TimestampData readTimestamp(FieldVector vector, int index, TimestampType tsType) { - long value; - int precision = tsType.getPrecision(); - - if (vector instanceof TimeStampSecVector) { - value = ((TimeStampSecVector) vector).get(index); - return TimestampData.fromEpochMillis(value * 1000); - } else if (vector instanceof TimeStampMilliVector) { - value = ((TimeStampMilliVector) vector).get(index); - return TimestampData.fromEpochMillis(value); - } else if (vector instanceof TimeStampMicroVector) { - value = ((TimeStampMicroVector) vector).get(index); - return TimestampData.fromEpochMillis(value / 1000, (int) ((value % 1000) * 1000)); - } else if (vector instanceof TimeStampNanoVector) { - value = ((TimeStampNanoVector) vector).get(index); - return TimestampData.fromEpochMillis(value / 1000000, (int) (value % 1000000)); + for (int fieldIndex = 0; fieldIndex < fieldTypes.length; fieldIndex++) { + String fieldName = fieldNames[fieldIndex]; + FieldVector vector = root.getVector(fieldName); + + if (vector == null) { + continue; } - throw new LanceTypeConverter.UnsupportedTypeException( - "Unsupported timestamp Vector type: " + 
vector.getClass().getSimpleName()); + Object value = getFieldValue(rowData, fieldIndex, fieldTypes[fieldIndex]); + writeValue(vector, rowIndex, value, fieldTypes[fieldIndex]); + } } - /** - * Read array value - */ - private ArrayData readArray(FieldVector vector, int index, ArrayType arrayType) { - LogicalType elementType = arrayType.getElementType(); - - if (vector instanceof FixedSizeListVector) { - FixedSizeListVector listVector = (FixedSizeListVector) vector; - int listSize = listVector.getListSize(); - FieldVector dataVector = listVector.getDataVector(); - int startIndex = index * listSize; - - return readArrayData(dataVector, startIndex, listSize, elementType); - } else if (vector instanceof ListVector) { - ListVector listVector = (ListVector) vector; - int startIndex = listVector.getElementStartIndex(index); - int endIndex = listVector.getElementEndIndex(index); - int listSize = endIndex - startIndex; - FieldVector dataVector = listVector.getDataVector(); - - return readArrayData(dataVector, startIndex, listSize, elementType); - } - - throw new LanceTypeConverter.UnsupportedTypeException( - "Unsupported array Vector type: " + vector.getClass().getSimpleName()); + root.setRowCount(rows.size()); + } + + /** + * Create VectorSchemaRoot + * + * @param allocator Memory allocator + * @return VectorSchemaRoot + */ + public VectorSchemaRoot createVectorSchemaRoot(BufferAllocator allocator) { + Schema arrowSchema = LanceTypeConverter.toArrowSchema(rowType); + return VectorSchemaRoot.create(arrowSchema, allocator); + } + + /** Read value from Arrow Vector */ + private Object readValue(FieldVector vector, int index, LogicalType logicalType) { + if (vector.isNull(index)) { + return null; } - /** - * Read array data - */ - private ArrayData readArrayData(FieldVector dataVector, int startIndex, int size, LogicalType elementType) { - if (elementType instanceof FloatType) { - Float4Vector float4Vector = (Float4Vector) dataVector; - Float[] values = new Float[size]; - for (int i = 0; i < size; i++) { - if (float4Vector.isNull(startIndex + i)) { - values[i] = null; - } else { - values[i] = float4Vector.get(startIndex + i); - } - } - return new GenericArrayData(values); - } else if (elementType instanceof DoubleType) { - Double8Vector double8Vector = (Double8Vector) dataVector; - Double[] values = new Double[size]; - for (int i = 0; i < size; i++) { - if (double8Vector.isNull(startIndex + i)) { - values[i] = null; - } else { - values[i] = double8Vector.get(startIndex + i); - } - } - return new GenericArrayData(values); - } else if (elementType instanceof IntType) { - IntVector intVector = (IntVector) dataVector; - Integer[] values = new Integer[size]; - for (int i = 0; i < size; i++) { - if (intVector.isNull(startIndex + i)) { - values[i] = null; - } else { - values[i] = intVector.get(startIndex + i); - } - } - return new GenericArrayData(values); - } else if (elementType instanceof BigIntType) { - BigIntVector bigIntVector = (BigIntVector) dataVector; - Long[] values = new Long[size]; - for (int i = 0; i < size; i++) { - if (bigIntVector.isNull(startIndex + i)) { - values[i] = null; - } else { - values[i] = bigIntVector.get(startIndex + i); - } - } - return new GenericArrayData(values); - } else if (elementType instanceof VarCharType) { - VarCharVector varCharVector = (VarCharVector) dataVector; - StringData[] values = new StringData[size]; - for (int i = 0; i < size; i++) { - if (varCharVector.isNull(startIndex + i)) { - values[i] = null; - } else { - values[i] = 
StringData.fromBytes(varCharVector.get(startIndex + i)); - } - } - return new GenericArrayData(values); - } + if (logicalType instanceof TinyIntType) { + return ((TinyIntVector) vector).get(index); + } else if (logicalType instanceof SmallIntType) { + return ((SmallIntVector) vector).get(index); + } else if (logicalType instanceof IntType) { + return ((IntVector) vector).get(index); + } else if (logicalType instanceof BigIntType) { + return ((BigIntVector) vector).get(index); + } else if (logicalType instanceof FloatType) { + return ((Float4Vector) vector).get(index); + } else if (logicalType instanceof DoubleType) { + return ((Float8Vector) vector).get(index); + } else if (logicalType instanceof VarCharType) { + byte[] bytes = ((VarCharVector) vector).get(index); + return StringData.fromBytes(bytes); + } else if (logicalType instanceof BooleanType) { + return ((BitVector) vector).get(index) == 1; + } else if (logicalType instanceof VarBinaryType) { + return ((VarBinaryVector) vector).get(index); + } else if (logicalType instanceof BinaryType) { + return ((FixedSizeBinaryVector) vector).get(index); + } else if (logicalType instanceof DateType) { + int daysSinceEpoch = ((DateDayVector) vector).get(index); + return daysSinceEpoch; + } else if (logicalType instanceof TimestampType) { + return readTimestamp(vector, index, (TimestampType) logicalType); + } else if (logicalType instanceof ArrayType) { + return readArray(vector, index, (ArrayType) logicalType); + } else if (logicalType instanceof RowType) { + return readStruct(vector, index, (RowType) logicalType); + } - throw new LanceTypeConverter.UnsupportedTypeException( - "Unsupported array element type: " + elementType.getClass().getSimpleName()); + throw new LanceTypeConverter.UnsupportedTypeException( + "Unsupported read type: " + logicalType.getClass().getSimpleName()); + } + + /** Read timestamp value */ + private TimestampData readTimestamp(FieldVector vector, int index, TimestampType tsType) { + long value; + int precision = tsType.getPrecision(); + + if (vector instanceof TimeStampSecVector) { + value = ((TimeStampSecVector) vector).get(index); + return TimestampData.fromEpochMillis(value * 1000); + } else if (vector instanceof TimeStampMilliVector) { + value = ((TimeStampMilliVector) vector).get(index); + return TimestampData.fromEpochMillis(value); + } else if (vector instanceof TimeStampMicroVector) { + value = ((TimeStampMicroVector) vector).get(index); + return TimestampData.fromEpochMillis(value / 1000, (int) ((value % 1000) * 1000)); + } else if (vector instanceof TimeStampNanoVector) { + value = ((TimeStampNanoVector) vector).get(index); + return TimestampData.fromEpochMillis(value / 1000000, (int) (value % 1000000)); } - /** - * Internal class for handling Double type Vector (alias for Float8Vector) - */ - private static class Double8Vector { - private final Float8Vector vector; + throw new LanceTypeConverter.UnsupportedTypeException( + "Unsupported timestamp Vector type: " + vector.getClass().getSimpleName()); + } + + /** Read array value */ + private ArrayData readArray(FieldVector vector, int index, ArrayType arrayType) { + LogicalType elementType = arrayType.getElementType(); + + if (vector instanceof FixedSizeListVector) { + FixedSizeListVector listVector = (FixedSizeListVector) vector; + int listSize = listVector.getListSize(); + FieldVector dataVector = listVector.getDataVector(); + int startIndex = index * listSize; + + return readArrayData(dataVector, startIndex, listSize, elementType); + } else if (vector 
instanceof ListVector) { + ListVector listVector = (ListVector) vector; + int startIndex = listVector.getElementStartIndex(index); + int endIndex = listVector.getElementEndIndex(index); + int listSize = endIndex - startIndex; + FieldVector dataVector = listVector.getDataVector(); + + return readArrayData(dataVector, startIndex, listSize, elementType); + } - Double8Vector(FieldVector vector) { - this.vector = (Float8Vector) vector; + throw new LanceTypeConverter.UnsupportedTypeException( + "Unsupported array Vector type: " + vector.getClass().getSimpleName()); + } + + /** Read array data */ + private ArrayData readArrayData( + FieldVector dataVector, int startIndex, int size, LogicalType elementType) { + if (elementType instanceof FloatType) { + Float4Vector float4Vector = (Float4Vector) dataVector; + Float[] values = new Float[size]; + for (int i = 0; i < size; i++) { + if (float4Vector.isNull(startIndex + i)) { + values[i] = null; + } else { + values[i] = float4Vector.get(startIndex + i); } - - boolean isNull(int index) { - return vector.isNull(index); + } + return new GenericArrayData(values); + } else if (elementType instanceof DoubleType) { + Double8Vector double8Vector = (Double8Vector) dataVector; + Double[] values = new Double[size]; + for (int i = 0; i < size; i++) { + if (double8Vector.isNull(startIndex + i)) { + values[i] = null; + } else { + values[i] = double8Vector.get(startIndex + i); } - - double get(int index) { - return vector.get(index); + } + return new GenericArrayData(values); + } else if (elementType instanceof IntType) { + IntVector intVector = (IntVector) dataVector; + Integer[] values = new Integer[size]; + for (int i = 0; i < size; i++) { + if (intVector.isNull(startIndex + i)) { + values[i] = null; + } else { + values[i] = intVector.get(startIndex + i); } - } - - /** - * Read struct value - */ - private RowData readStruct(FieldVector vector, int index, RowType rowType) { - StructVector structVector = (StructVector) vector; - List fields = rowType.getFields(); - GenericRowData rowData = new GenericRowData(fields.size()); - - for (int i = 0; i < fields.size(); i++) { - RowType.RowField field = fields.get(i); - FieldVector childVector = structVector.getChild(field.getName()); - if (childVector == null) { - rowData.setField(i, null); - } else { - Object value = readValue(childVector, index, field.getType()); - rowData.setField(i, value); - } + } + return new GenericArrayData(values); + } else if (elementType instanceof BigIntType) { + BigIntVector bigIntVector = (BigIntVector) dataVector; + Long[] values = new Long[size]; + for (int i = 0; i < size; i++) { + if (bigIntVector.isNull(startIndex + i)) { + values[i] = null; + } else { + values[i] = bigIntVector.get(startIndex + i); } - - return rowData; + } + return new GenericArrayData(values); + } else if (elementType instanceof VarCharType) { + VarCharVector varCharVector = (VarCharVector) dataVector; + StringData[] values = new StringData[size]; + for (int i = 0; i < size; i++) { + if (varCharVector.isNull(startIndex + i)) { + values[i] = null; + } else { + values[i] = StringData.fromBytes(varCharVector.get(startIndex + i)); + } + } + return new GenericArrayData(values); } - /** - * Get field value from RowData - */ - private Object getFieldValue(RowData rowData, int index, LogicalType logicalType) { - if (rowData.isNullAt(index)) { - return null; - } + throw new LanceTypeConverter.UnsupportedTypeException( + "Unsupported array element type: " + elementType.getClass().getSimpleName()); + } - if (logicalType 
instanceof TinyIntType) { - return rowData.getByte(index); - } else if (logicalType instanceof SmallIntType) { - return rowData.getShort(index); - } else if (logicalType instanceof IntType) { - return rowData.getInt(index); - } else if (logicalType instanceof BigIntType) { - return rowData.getLong(index); - } else if (logicalType instanceof FloatType) { - return rowData.getFloat(index); - } else if (logicalType instanceof DoubleType) { - return rowData.getDouble(index); - } else if (logicalType instanceof VarCharType) { - return rowData.getString(index); - } else if (logicalType instanceof BooleanType) { - return rowData.getBoolean(index); - } else if (logicalType instanceof VarBinaryType || logicalType instanceof BinaryType) { - return rowData.getBinary(index); - } else if (logicalType instanceof DateType) { - return rowData.getInt(index); - } else if (logicalType instanceof TimestampType) { - TimestampType tsType = (TimestampType) logicalType; - return rowData.getTimestamp(index, tsType.getPrecision()); - } else if (logicalType instanceof ArrayType) { - return rowData.getArray(index); - } else if (logicalType instanceof RowType) { - RowType nestedRowType = (RowType) logicalType; - return rowData.getRow(index, nestedRowType.getFieldCount()); - } + /** Internal class for handling Double type Vector (alias for Float8Vector) */ + private static class Double8Vector { + private final Float8Vector vector; - throw new LanceTypeConverter.UnsupportedTypeException( - "Unsupported get type: " + logicalType.getClass().getSimpleName()); + Double8Vector(FieldVector vector) { + this.vector = (Float8Vector) vector; } - /** - * Write value to Arrow Vector - */ - private void writeValue(FieldVector vector, int index, Object value, LogicalType logicalType) { - if (value == null) { - setNull(vector, index); - return; - } - - if (logicalType instanceof TinyIntType) { - ((TinyIntVector) vector).setSafe(index, (byte) value); - } else if (logicalType instanceof SmallIntType) { - ((SmallIntVector) vector).setSafe(index, (short) value); - } else if (logicalType instanceof IntType) { - ((IntVector) vector).setSafe(index, (int) value); - } else if (logicalType instanceof BigIntType) { - ((BigIntVector) vector).setSafe(index, (long) value); - } else if (logicalType instanceof FloatType) { - ((Float4Vector) vector).setSafe(index, (float) value); - } else if (logicalType instanceof DoubleType) { - ((Float8Vector) vector).setSafe(index, (double) value); - } else if (logicalType instanceof VarCharType) { - StringData stringData = (StringData) value; - ((VarCharVector) vector).setSafe(index, stringData.toBytes()); - } else if (logicalType instanceof BooleanType) { - ((BitVector) vector).setSafe(index, (boolean) value ? 
1 : 0); - } else if (logicalType instanceof VarBinaryType) { - ((VarBinaryVector) vector).setSafe(index, (byte[]) value); - } else if (logicalType instanceof BinaryType) { - ((FixedSizeBinaryVector) vector).setSafe(index, (byte[]) value); - } else if (logicalType instanceof DateType) { - ((DateDayVector) vector).setSafe(index, (int) value); - } else if (logicalType instanceof TimestampType) { - writeTimestamp(vector, index, (TimestampData) value, (TimestampType) logicalType); - } else if (logicalType instanceof ArrayType) { - writeArray(vector, index, (ArrayData) value, (ArrayType) logicalType); - } else if (logicalType instanceof RowType) { - writeStruct(vector, index, (RowData) value, (RowType) logicalType); - } else { - throw new LanceTypeConverter.UnsupportedTypeException( - "Unsupported write type: " + logicalType.getClass().getSimpleName()); - } + boolean isNull(int index) { + return vector.isNull(index); } - /** - * Set null value - */ - private void setNull(FieldVector vector, int index) { - if (vector instanceof TinyIntVector) { - ((TinyIntVector) vector).setNull(index); - } else if (vector instanceof SmallIntVector) { - ((SmallIntVector) vector).setNull(index); - } else if (vector instanceof IntVector) { - ((IntVector) vector).setNull(index); - } else if (vector instanceof BigIntVector) { - ((BigIntVector) vector).setNull(index); - } else if (vector instanceof Float4Vector) { - ((Float4Vector) vector).setNull(index); - } else if (vector instanceof Float8Vector) { - ((Float8Vector) vector).setNull(index); - } else if (vector instanceof VarCharVector) { - ((VarCharVector) vector).setNull(index); - } else if (vector instanceof BitVector) { - ((BitVector) vector).setNull(index); - } else if (vector instanceof VarBinaryVector) { - ((VarBinaryVector) vector).setNull(index); - } else if (vector instanceof FixedSizeBinaryVector) { - ((FixedSizeBinaryVector) vector).setNull(index); - } else if (vector instanceof DateDayVector) { - ((DateDayVector) vector).setNull(index); - } else if (vector instanceof TimeStampSecVector) { - ((TimeStampSecVector) vector).setNull(index); - } else if (vector instanceof TimeStampMilliVector) { - ((TimeStampMilliVector) vector).setNull(index); - } else if (vector instanceof TimeStampMicroVector) { - ((TimeStampMicroVector) vector).setNull(index); - } else if (vector instanceof TimeStampNanoVector) { - ((TimeStampNanoVector) vector).setNull(index); - } else if (vector instanceof FixedSizeListVector) { - ((FixedSizeListVector) vector).setNull(index); - } else if (vector instanceof ListVector) { - ((ListVector) vector).setNull(index); - } else if (vector instanceof StructVector) { - ((StructVector) vector).setNull(index); - } + double get(int index) { + return vector.get(index); } - - /** - * Write timestamp value - */ - private void writeTimestamp(FieldVector vector, int index, TimestampData tsData, TimestampType tsType) { - long millis = tsData.getMillisecond(); - int nanos = tsData.getNanoOfMillisecond(); - - if (vector instanceof TimeStampSecVector) { - ((TimeStampSecVector) vector).setSafe(index, millis / 1000); - } else if (vector instanceof TimeStampMilliVector) { - ((TimeStampMilliVector) vector).setSafe(index, millis); - } else if (vector instanceof TimeStampMicroVector) { - long micros = millis * 1000 + nanos / 1000; - ((TimeStampMicroVector) vector).setSafe(index, micros); - } else if (vector instanceof TimeStampNanoVector) { - long totalNanos = millis * 1000000 + nanos; - ((TimeStampNanoVector) vector).setSafe(index, totalNanos); - } else { - throw 
new LanceTypeConverter.UnsupportedTypeException( - "Unsupported timestamp Vector type: " + vector.getClass().getSimpleName()); - } + } + + /** Read struct value */ + private RowData readStruct(FieldVector vector, int index, RowType rowType) { + StructVector structVector = (StructVector) vector; + List fields = rowType.getFields(); + GenericRowData rowData = new GenericRowData(fields.size()); + + for (int i = 0; i < fields.size(); i++) { + RowType.RowField field = fields.get(i); + FieldVector childVector = structVector.getChild(field.getName()); + if (childVector == null) { + rowData.setField(i, null); + } else { + Object value = readValue(childVector, index, field.getType()); + rowData.setField(i, value); + } } - /** - * Write array value - */ - private void writeArray(FieldVector vector, int index, ArrayData arrayData, ArrayType arrayType) { - LogicalType elementType = arrayType.getElementType(); - int size = arrayData.size(); - - if (vector instanceof FixedSizeListVector) { - FixedSizeListVector listVector = (FixedSizeListVector) vector; - int listSize = listVector.getListSize(); - - if (size != listSize) { - throw new IllegalArgumentException( - "Array size " + size + " does not match FixedSizeList size " + listSize); - } - - FieldVector dataVector = listVector.getDataVector(); - int startIndex = index * listSize; - - writeArrayData(dataVector, startIndex, arrayData, elementType); - listVector.setNotNull(index); - } else if (vector instanceof ListVector) { - ListVector listVector = (ListVector) vector; - listVector.startNewValue(index); - - FieldVector dataVector = listVector.getDataVector(); - int startIndex = listVector.getElementStartIndex(index); - - writeArrayData(dataVector, startIndex, arrayData, elementType); - listVector.endValue(index, size); - } else { - throw new LanceTypeConverter.UnsupportedTypeException( - "Unsupported array Vector type: " + vector.getClass().getSimpleName()); - } + return rowData; + } + + /** Get field value from RowData */ + private Object getFieldValue(RowData rowData, int index, LogicalType logicalType) { + if (rowData.isNullAt(index)) { + return null; } - /** - * Write array data - */ - private void writeArrayData(FieldVector dataVector, int startIndex, ArrayData arrayData, LogicalType elementType) { - int size = arrayData.size(); - - if (elementType instanceof FloatType) { - Float4Vector float4Vector = (Float4Vector) dataVector; - for (int i = 0; i < size; i++) { - if (arrayData.isNullAt(i)) { - float4Vector.setNull(startIndex + i); - } else { - float4Vector.setSafe(startIndex + i, arrayData.getFloat(i)); - } - } - } else if (elementType instanceof DoubleType) { - Float8Vector float8Vector = (Float8Vector) dataVector; - for (int i = 0; i < size; i++) { - if (arrayData.isNullAt(i)) { - float8Vector.setNull(startIndex + i); - } else { - float8Vector.setSafe(startIndex + i, arrayData.getDouble(i)); - } - } - } else if (elementType instanceof IntType) { - IntVector intVector = (IntVector) dataVector; - for (int i = 0; i < size; i++) { - if (arrayData.isNullAt(i)) { - intVector.setNull(startIndex + i); - } else { - intVector.setSafe(startIndex + i, arrayData.getInt(i)); - } - } - } else if (elementType instanceof BigIntType) { - BigIntVector bigIntVector = (BigIntVector) dataVector; - for (int i = 0; i < size; i++) { - if (arrayData.isNullAt(i)) { - bigIntVector.setNull(startIndex + i); - } else { - bigIntVector.setSafe(startIndex + i, arrayData.getLong(i)); - } - } - } else if (elementType instanceof VarCharType) { - VarCharVector varCharVector = 
(VarCharVector) dataVector; - for (int i = 0; i < size; i++) { - if (arrayData.isNullAt(i)) { - varCharVector.setNull(startIndex + i); - } else { - StringData stringData = arrayData.getString(i); - varCharVector.setSafe(startIndex + i, stringData.toBytes()); - } - } - } else { - throw new LanceTypeConverter.UnsupportedTypeException( - "Unsupported array element type: " + elementType.getClass().getSimpleName()); - } + if (logicalType instanceof TinyIntType) { + return rowData.getByte(index); + } else if (logicalType instanceof SmallIntType) { + return rowData.getShort(index); + } else if (logicalType instanceof IntType) { + return rowData.getInt(index); + } else if (logicalType instanceof BigIntType) { + return rowData.getLong(index); + } else if (logicalType instanceof FloatType) { + return rowData.getFloat(index); + } else if (logicalType instanceof DoubleType) { + return rowData.getDouble(index); + } else if (logicalType instanceof VarCharType) { + return rowData.getString(index); + } else if (logicalType instanceof BooleanType) { + return rowData.getBoolean(index); + } else if (logicalType instanceof VarBinaryType || logicalType instanceof BinaryType) { + return rowData.getBinary(index); + } else if (logicalType instanceof DateType) { + return rowData.getInt(index); + } else if (logicalType instanceof TimestampType) { + TimestampType tsType = (TimestampType) logicalType; + return rowData.getTimestamp(index, tsType.getPrecision()); + } else if (logicalType instanceof ArrayType) { + return rowData.getArray(index); + } else if (logicalType instanceof RowType) { + RowType nestedRowType = (RowType) logicalType; + return rowData.getRow(index, nestedRowType.getFieldCount()); } - /** - * Write struct value - */ - private void writeStruct(FieldVector vector, int index, RowData rowData, RowType rowType) { - StructVector structVector = (StructVector) vector; - List fields = rowType.getFields(); - - for (int i = 0; i < fields.size(); i++) { - RowType.RowField field = fields.get(i); - FieldVector childVector = structVector.getChild(field.getName()); - if (childVector != null) { - Object value = getFieldValue(rowData, i, field.getType()); - writeValue(childVector, index, value, field.getType()); - } - } + throw new LanceTypeConverter.UnsupportedTypeException( + "Unsupported get type: " + logicalType.getClass().getSimpleName()); + } - structVector.setIndexDefined(index); + /** Write value to Arrow Vector */ + private void writeValue(FieldVector vector, int index, Object value, LogicalType logicalType) { + if (value == null) { + setNull(vector, index); + return; } - /** - * Convert float array to ArrayData - * - * @param vector float array - * @return ArrayData - */ - public static ArrayData toArrayData(float[] vector) { - if (vector == null) { - return null; - } - Float[] boxed = new Float[vector.length]; - for (int i = 0; i < vector.length; i++) { - boxed[i] = vector[i]; - } - return new GenericArrayData(boxed); + if (logicalType instanceof TinyIntType) { + ((TinyIntVector) vector).setSafe(index, (byte) value); + } else if (logicalType instanceof SmallIntType) { + ((SmallIntVector) vector).setSafe(index, (short) value); + } else if (logicalType instanceof IntType) { + ((IntVector) vector).setSafe(index, (int) value); + } else if (logicalType instanceof BigIntType) { + ((BigIntVector) vector).setSafe(index, (long) value); + } else if (logicalType instanceof FloatType) { + ((Float4Vector) vector).setSafe(index, (float) value); + } else if (logicalType instanceof DoubleType) { + ((Float8Vector) 
vector).setSafe(index, (double) value); + } else if (logicalType instanceof VarCharType) { + StringData stringData = (StringData) value; + ((VarCharVector) vector).setSafe(index, stringData.toBytes()); + } else if (logicalType instanceof BooleanType) { + ((BitVector) vector).setSafe(index, (boolean) value ? 1 : 0); + } else if (logicalType instanceof VarBinaryType) { + ((VarBinaryVector) vector).setSafe(index, (byte[]) value); + } else if (logicalType instanceof BinaryType) { + ((FixedSizeBinaryVector) vector).setSafe(index, (byte[]) value); + } else if (logicalType instanceof DateType) { + ((DateDayVector) vector).setSafe(index, (int) value); + } else if (logicalType instanceof TimestampType) { + writeTimestamp(vector, index, (TimestampData) value, (TimestampType) logicalType); + } else if (logicalType instanceof ArrayType) { + writeArray(vector, index, (ArrayData) value, (ArrayType) logicalType); + } else if (logicalType instanceof RowType) { + writeStruct(vector, index, (RowData) value, (RowType) logicalType); + } else { + throw new LanceTypeConverter.UnsupportedTypeException( + "Unsupported write type: " + logicalType.getClass().getSimpleName()); } - - /** - * Convert double array to ArrayData - * - * @param vector double array - * @return ArrayData - */ - public static ArrayData toArrayData(double[] vector) { - if (vector == null) { - return null; - } - Double[] boxed = new Double[vector.length]; - for (int i = 0; i < vector.length; i++) { - boxed[i] = vector[i]; - } - return new GenericArrayData(boxed); + } + + /** Set null value */ + private void setNull(FieldVector vector, int index) { + if (vector instanceof TinyIntVector) { + ((TinyIntVector) vector).setNull(index); + } else if (vector instanceof SmallIntVector) { + ((SmallIntVector) vector).setNull(index); + } else if (vector instanceof IntVector) { + ((IntVector) vector).setNull(index); + } else if (vector instanceof BigIntVector) { + ((BigIntVector) vector).setNull(index); + } else if (vector instanceof Float4Vector) { + ((Float4Vector) vector).setNull(index); + } else if (vector instanceof Float8Vector) { + ((Float8Vector) vector).setNull(index); + } else if (vector instanceof VarCharVector) { + ((VarCharVector) vector).setNull(index); + } else if (vector instanceof BitVector) { + ((BitVector) vector).setNull(index); + } else if (vector instanceof VarBinaryVector) { + ((VarBinaryVector) vector).setNull(index); + } else if (vector instanceof FixedSizeBinaryVector) { + ((FixedSizeBinaryVector) vector).setNull(index); + } else if (vector instanceof DateDayVector) { + ((DateDayVector) vector).setNull(index); + } else if (vector instanceof TimeStampSecVector) { + ((TimeStampSecVector) vector).setNull(index); + } else if (vector instanceof TimeStampMilliVector) { + ((TimeStampMilliVector) vector).setNull(index); + } else if (vector instanceof TimeStampMicroVector) { + ((TimeStampMicroVector) vector).setNull(index); + } else if (vector instanceof TimeStampNanoVector) { + ((TimeStampNanoVector) vector).setNull(index); + } else if (vector instanceof FixedSizeListVector) { + ((FixedSizeListVector) vector).setNull(index); + } else if (vector instanceof ListVector) { + ((ListVector) vector).setNull(index); + } else if (vector instanceof StructVector) { + ((StructVector) vector).setNull(index); } - - /** - * Convert ArrayData to float array - * - * @param arrayData ArrayData - * @return float array - */ - public static float[] toFloatArray(ArrayData arrayData) { - if (arrayData == null) { - return null; + } + + /** Write timestamp value 
*/ + private void writeTimestamp( + FieldVector vector, int index, TimestampData tsData, TimestampType tsType) { + long millis = tsData.getMillisecond(); + int nanos = tsData.getNanoOfMillisecond(); + + if (vector instanceof TimeStampSecVector) { + ((TimeStampSecVector) vector).setSafe(index, millis / 1000); + } else if (vector instanceof TimeStampMilliVector) { + ((TimeStampMilliVector) vector).setSafe(index, millis); + } else if (vector instanceof TimeStampMicroVector) { + long micros = millis * 1000 + nanos / 1000; + ((TimeStampMicroVector) vector).setSafe(index, micros); + } else if (vector instanceof TimeStampNanoVector) { + long totalNanos = millis * 1000000 + nanos; + ((TimeStampNanoVector) vector).setSafe(index, totalNanos); + } else { + throw new LanceTypeConverter.UnsupportedTypeException( + "Unsupported timestamp Vector type: " + vector.getClass().getSimpleName()); + } + } + + /** Write array value */ + private void writeArray(FieldVector vector, int index, ArrayData arrayData, ArrayType arrayType) { + LogicalType elementType = arrayType.getElementType(); + int size = arrayData.size(); + + if (vector instanceof FixedSizeListVector) { + FixedSizeListVector listVector = (FixedSizeListVector) vector; + int listSize = listVector.getListSize(); + + if (size != listSize) { + throw new IllegalArgumentException( + "Array size " + size + " does not match FixedSizeList size " + listSize); + } + + FieldVector dataVector = listVector.getDataVector(); + int startIndex = index * listSize; + + writeArrayData(dataVector, startIndex, arrayData, elementType); + listVector.setNotNull(index); + } else if (vector instanceof ListVector) { + ListVector listVector = (ListVector) vector; + listVector.startNewValue(index); + + FieldVector dataVector = listVector.getDataVector(); + int startIndex = listVector.getElementStartIndex(index); + + writeArrayData(dataVector, startIndex, arrayData, elementType); + listVector.endValue(index, size); + } else { + throw new LanceTypeConverter.UnsupportedTypeException( + "Unsupported array Vector type: " + vector.getClass().getSimpleName()); + } + } + + /** Write array data */ + private void writeArrayData( + FieldVector dataVector, int startIndex, ArrayData arrayData, LogicalType elementType) { + int size = arrayData.size(); + + if (elementType instanceof FloatType) { + Float4Vector float4Vector = (Float4Vector) dataVector; + for (int i = 0; i < size; i++) { + if (arrayData.isNullAt(i)) { + float4Vector.setNull(startIndex + i); + } else { + float4Vector.setSafe(startIndex + i, arrayData.getFloat(i)); } - int size = arrayData.size(); - float[] result = new float[size]; - for (int i = 0; i < size; i++) { - result[i] = arrayData.getFloat(i); + } + } else if (elementType instanceof DoubleType) { + Float8Vector float8Vector = (Float8Vector) dataVector; + for (int i = 0; i < size; i++) { + if (arrayData.isNullAt(i)) { + float8Vector.setNull(startIndex + i); + } else { + float8Vector.setSafe(startIndex + i, arrayData.getDouble(i)); } - return result; - } - - /** - * Convert ArrayData to double array - * - * @param arrayData ArrayData - * @return double array - */ - public static double[] toDoubleArray(ArrayData arrayData) { - if (arrayData == null) { - return null; + } + } else if (elementType instanceof IntType) { + IntVector intVector = (IntVector) dataVector; + for (int i = 0; i < size; i++) { + if (arrayData.isNullAt(i)) { + intVector.setNull(startIndex + i); + } else { + intVector.setSafe(startIndex + i, arrayData.getInt(i)); + } + } + } else if (elementType instanceof 
BigIntType) { + BigIntVector bigIntVector = (BigIntVector) dataVector; + for (int i = 0; i < size; i++) { + if (arrayData.isNullAt(i)) { + bigIntVector.setNull(startIndex + i); + } else { + bigIntVector.setSafe(startIndex + i, arrayData.getLong(i)); } - int size = arrayData.size(); - double[] result = new double[size]; - for (int i = 0; i < size; i++) { - result[i] = arrayData.getDouble(i); + } + } else if (elementType instanceof VarCharType) { + VarCharVector varCharVector = (VarCharVector) dataVector; + for (int i = 0; i < size; i++) { + if (arrayData.isNullAt(i)) { + varCharVector.setNull(startIndex + i); + } else { + StringData stringData = arrayData.getString(i); + varCharVector.setSafe(startIndex + i, stringData.toBytes()); } - return result; + } + } else { + throw new LanceTypeConverter.UnsupportedTypeException( + "Unsupported array element type: " + elementType.getClass().getSimpleName()); } - - /** - * Get RowType - */ - public RowType getRowType() { - return rowType; + } + + /** Write struct value */ + private void writeStruct(FieldVector vector, int index, RowData rowData, RowType rowType) { + StructVector structVector = (StructVector) vector; + List fields = rowType.getFields(); + + for (int i = 0; i < fields.size(); i++) { + RowType.RowField field = fields.get(i); + FieldVector childVector = structVector.getChild(field.getName()); + if (childVector != null) { + Object value = getFieldValue(rowData, i, field.getType()); + writeValue(childVector, index, value, field.getType()); + } } - /** - * Get field name array - */ - public String[] getFieldNames() { - return fieldNames; + structVector.setIndexDefined(index); + } + + /** + * Convert float array to ArrayData + * + * @param vector float array + * @return ArrayData + */ + public static ArrayData toArrayData(float[] vector) { + if (vector == null) { + return null; } - - /** - * Get field type array - */ - public LogicalType[] getFieldTypes() { - return fieldTypes; + Float[] boxed = new Float[vector.length]; + for (int i = 0; i < vector.length; i++) { + boxed[i] = vector[i]; + } + return new GenericArrayData(boxed); + } + + /** + * Convert double array to ArrayData + * + * @param vector double array + * @return ArrayData + */ + public static ArrayData toArrayData(double[] vector) { + if (vector == null) { + return null; + } + Double[] boxed = new Double[vector.length]; + for (int i = 0; i < vector.length; i++) { + boxed[i] = vector[i]; + } + return new GenericArrayData(boxed); + } + + /** + * Convert ArrayData to float array + * + * @param arrayData ArrayData + * @return float array + */ + public static float[] toFloatArray(ArrayData arrayData) { + if (arrayData == null) { + return null; + } + int size = arrayData.size(); + float[] result = new float[size]; + for (int i = 0; i < size; i++) { + result[i] = arrayData.getFloat(i); + } + return result; + } + + /** + * Convert ArrayData to double array + * + * @param arrayData ArrayData + * @return double array + */ + public static double[] toDoubleArray(ArrayData arrayData) { + if (arrayData == null) { + return null; + } + int size = arrayData.size(); + double[] result = new double[size]; + for (int i = 0; i < size; i++) { + result[i] = arrayData.getDouble(i); } + return result; + } + + /** Get RowType */ + public RowType getRowType() { + return rowType; + } + + /** Get field name array */ + public String[] getFieldNames() { + return fieldNames; + } + + /** Get field type array */ + public LogicalType[] getFieldTypes() { + return fieldTypes; + } } diff --git 
a/src/main/java/org/apache/flink/connector/lance/table/LanceCatalog.java b/src/main/java/org/apache/flink/connector/lance/table/LanceCatalog.java
index 74fed94..bb719cc 100644
--- a/src/main/java/org/apache/flink/connector/lance/table/LanceCatalog.java
+++ b/src/main/java/org/apache/flink/connector/lance/table/LanceCatalog.java
@@ -1,11 +1,7 @@
 /*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
@@ -15,11 +11,9 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.flink.connector.lance.table;

 import org.apache.flink.connector.lance.converter.LanceTypeConverter;
-import org.apache.flink.table.api.DataTypes;
 import org.apache.flink.table.api.Schema;
 import org.apache.flink.table.catalog.AbstractCatalog;
 import org.apache.flink.table.catalog.CatalogBaseTable;
@@ -30,7 +24,6 @@
 import org.apache.flink.table.catalog.CatalogPartitionSpec;
 import org.apache.flink.table.catalog.CatalogTable;
 import org.apache.flink.table.catalog.ObjectPath;
-import org.apache.flink.table.catalog.ResolvedCatalogTable;
 import org.apache.flink.table.catalog.exceptions.CatalogException;
 import org.apache.flink.table.catalog.exceptions.DatabaseAlreadyExistException;
 import org.apache.flink.table.catalog.exceptions.DatabaseNotEmptyException;
@@ -56,14 +49,12 @@
 import org.slf4j.LoggerFactory;

 import java.io.IOException;
-import java.net.URI;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -72,11 +63,12 @@

 /**
  * Lance Catalog implementation.
- *
- * <p>Implements Flink Catalog interface, supports managing Lance datasets as Flink tables.
- * Supports local file system and S3 protocol object storage.
- *
+ *
+ * <p>Implements Flink Catalog interface, supports managing Lance datasets as Flink tables. Supports
+ * local file system and S3 protocol object storage.
+ *
  * <p>Usage example (local path):
+ *
  * <pre>{@code
  * CREATE CATALOG lance_catalog WITH (
  *     'type' = 'lance',
@@ -84,8 +76,9 @@
  *     'default-database' = 'default'
  * );
  * }</pre>
- *
+ *
  * <p>Usage example (S3 path):
+ *
  * <pre>{@code
  * CREATE CATALOG lance_s3_catalog WITH (
  *     'type' = 'lance',
@@ -99,756 +92,790 @@
  */
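For orientation, the catalog declared in the javadoc above can also be registered programmatically. A minimal sketch using the Table API, where the catalog name and warehouse path are illustrative assumptions and the WITH options mirror the local-path example:

    import org.apache.flink.table.api.EnvironmentSettings;
    import org.apache.flink.table.api.TableEnvironment;

    public class LanceCatalogUsageSketch {
      public static void main(String[] args) {
        TableEnvironment env = TableEnvironment.create(EnvironmentSettings.inBatchMode());
        // Same options as the local-path example in the javadoc above.
        env.executeSql(
            "CREATE CATALOG lance_catalog WITH ("
                + " 'type' = 'lance',"
                + " 'warehouse' = '/tmp/lance_warehouse',"
                + " 'default-database' = 'default')");
        env.executeSql("USE CATALOG lance_catalog");
        // Lists the Lance datasets found under the warehouse.
        env.executeSql("SHOW TABLES").print();
      }
    }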
 public class LanceCatalog extends AbstractCatalog {
 
-    private static final Logger LOG = LoggerFactory.getLogger(LanceCatalog.class);
-
-    public static final String DEFAULT_DATABASE = "default";
-
-    private final String warehouse;
-    private final Map<String, String> storageOptions;
-    private final boolean isRemoteStorage;
-    private transient BufferAllocator allocator;
-    
-    // Cache known databases and tables for remote storage
-    private final Set<String> knownDatabases = ConcurrentHashMap.newKeySet();
-    private final Set<String> knownTables = ConcurrentHashMap.newKeySet();
-
-    /**
-     * Create LanceCatalog (local storage)
-     *
-     * @param name Catalog name
-     * @param defaultDatabase Default database name
-     * @param warehouse Warehouse path
-     */
-    public LanceCatalog(String name, String defaultDatabase, String warehouse) {
-        this(name, defaultDatabase, warehouse, Collections.emptyMap());
-    }
-
-    /**
-     * Create LanceCatalog (supports remote storage)
-     *
-     * @param name Catalog name
-     * @param defaultDatabase Default database name
-     * @param warehouse Warehouse path (local path or S3 URI)
-     * @param storageOptions Storage configuration options (e.g., S3 credentials)
-     */
-    public LanceCatalog(String name, String defaultDatabase, String warehouse, Map<String, String> storageOptions) {
-        super(name, defaultDatabase);
-        this.warehouse = normalizeWarehousePath(warehouse);
-        this.storageOptions = storageOptions != null ? new HashMap<>(storageOptions) : Collections.emptyMap();
-        this.isRemoteStorage = isRemotePath(warehouse);
-    }
-
-    /**
-     * Check if path is remote storage path
-     */
-    private boolean isRemotePath(String path) {
-        if (path == null) {
-            return false;
-        }
-        String lowerPath = path.toLowerCase();
-        return lowerPath.startsWith("s3://") || 
-               lowerPath.startsWith("s3a://") || 
-               lowerPath.startsWith("gs://") || 
-               lowerPath.startsWith("az://") ||
-               lowerPath.startsWith("https://") ||
-               lowerPath.startsWith("http://");
-    }
-
-    /**
-     * Normalize warehouse path
-     */
-    private String normalizeWarehousePath(String path) {
-        if (path == null) {
-            return null;
-        }
-        // Remove trailing slashes
-        while (path.endsWith("/") && path.length() > 1) {
-            path = path.substring(0, path.length() - 1);
-        }
-        return path;
-    }
-
-    @Override
-    public void open() throws CatalogException {
-        LOG.info("Opening Lance Catalog: {}, warehouse path: {}, remote storage: {}", getName(), warehouse, isRemoteStorage);
-        
-        this.allocator = new RootAllocator(Long.MAX_VALUE);
-        
-        if (isRemoteStorage) {
-            // Remote storage: initialize default database record
-            knownDatabases.add(getDefaultDatabase());
-            LOG.info("Remote storage mode enabled, storage config count: {}", storageOptions.size());
-        } else {
-            // Local storage: ensure warehouse directory exists
-            Path warehousePath = Paths.get(warehouse);
-            if (!Files.exists(warehousePath)) {
-                try {
-                    Files.createDirectories(warehousePath);
-                } catch (IOException e) {
-                    throw new CatalogException("Cannot create warehouse directory: " + warehouse, e);
-                }
-            }
-            
-            // Ensure default database exists
-            Path defaultDbPath = warehousePath.resolve(getDefaultDatabase());
-            if (!Files.exists(defaultDbPath)) {
-                try {
-                    Files.createDirectories(defaultDbPath);
-                } catch (IOException e) {
-                    throw new CatalogException("Cannot create default database directory: " + defaultDbPath, e);
-                }
-            }
-        }
-    }
-
-    @Override
-    public void close() throws CatalogException {
-        LOG.info("Closing Lance Catalog: {}", getName());
-        
-        if (allocator != null) {
-            try {
-                allocator.close();
-            } catch (Exception e) {
-                LOG.warn("Failed to close allocator", e);
-            }
-            allocator = null;
-        }
-        
-        knownDatabases.clear();
-        knownTables.clear();
-    }
-
-    // ==================== Database Operations ====================
-
-    @Override
-    public List<String> listDatabases() throws CatalogException {
-        if (isRemoteStorage) {
-            // Remote storage: return known database list
-            return new ArrayList<>(knownDatabases);
-        }
-        
+  private static final Logger LOG = LoggerFactory.getLogger(LanceCatalog.class);
+
+  public static final String DEFAULT_DATABASE = "default";
+
+  private final String warehouse;
+  private final Map<String, String> storageOptions;
+  private final boolean isRemoteStorage;
+  private transient BufferAllocator allocator;
+
+  // Cache known databases and tables for remote storage
+  private final Set<String> knownDatabases = ConcurrentHashMap.newKeySet();
+  private final Set<String> knownTables = ConcurrentHashMap.newKeySet();
+
+  /**
+   * Create LanceCatalog (local storage)
+   *
+   * @param name Catalog name
+   * @param defaultDatabase Default database name
+   * @param warehouse Warehouse path
+   */
+  public LanceCatalog(String name, String defaultDatabase, String warehouse) {
+    this(name, defaultDatabase, warehouse, Collections.emptyMap());
+  }
+
+  /**
+   * Create LanceCatalog (supports remote storage)
+   *
+   * @param name Catalog name
+   * @param defaultDatabase Default database name
+   * @param warehouse Warehouse path (local path or S3 URI)
+   * @param storageOptions Storage configuration options (e.g., S3 credentials)
+   */
+  public LanceCatalog(
+      String name, String defaultDatabase, String warehouse, Map<String, String> storageOptions) {
+    super(name, defaultDatabase);
+    this.warehouse = normalizeWarehousePath(warehouse);
+    this.storageOptions =
+        storageOptions != null ? new HashMap<>(storageOptions) : Collections.emptyMap();
+    this.isRemoteStorage = isRemotePath(warehouse);
+  }
+
+  /** Check if path is remote storage path */
+  private boolean isRemotePath(String path) {
+    if (path == null) {
+      return false;
+    }
+    String lowerPath = path.toLowerCase();
+    return lowerPath.startsWith("s3://")
+        || lowerPath.startsWith("s3a://")
+        || lowerPath.startsWith("gs://")
+        || lowerPath.startsWith("az://")
+        || lowerPath.startsWith("https://")
+        || lowerPath.startsWith("http://");
+  }
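The predicate above is what flips the catalog into remote mode. A self-contained probe of the same scheme checks, for reference; the class name and the REMOTE_SCHEMES constant are illustrative, but the prefixes are exactly the ones listed in the method:

    import java.util.Arrays;
    import java.util.List;

    public class RemotePathProbe {
      private static final List<String> REMOTE_SCHEMES =
          Arrays.asList("s3://", "s3a://", "gs://", "az://", "https://", "http://");

      static boolean isRemote(String path) {
        if (path == null) {
          return false;
        }
        String lower = path.toLowerCase();
        return REMOTE_SCHEMES.stream().anyMatch(lower::startsWith);
      }

      public static void main(String[] args) {
        System.out.println(isRemote("s3://bucket/warehouse")); // true -> remote mode
        System.out.println(isRemote("/tmp/lance_warehouse"));  // false -> local mode
      }
    }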
+
+  /** Normalize warehouse path */
+  private String normalizeWarehousePath(String path) {
+    if (path == null) {
+      return null;
+    }
+    // Remove trailing slashes
+    while (path.endsWith("/") && path.length() > 1) {
+      path = path.substring(0, path.length() - 1);
+    }
+    return path;
+  }
+
+  @Override
+  public void open() throws CatalogException {
+    LOG.info(
+        "Opening Lance Catalog: {}, warehouse path: {}, remote storage: {}",
+        getName(),
+        warehouse,
+        isRemoteStorage);
+
+    this.allocator = new RootAllocator(Long.MAX_VALUE);
+
+    if (isRemoteStorage) {
+      // Remote storage: initialize default database record
+      knownDatabases.add(getDefaultDatabase());
+      LOG.info("Remote storage mode enabled, storage config count: {}", storageOptions.size());
+    } else {
+      // Local storage: ensure warehouse directory exists
+      Path warehousePath = Paths.get(warehouse);
+      if (!Files.exists(warehousePath)) {
         try {
-            Path warehousePath = Paths.get(warehouse);
-            if (!Files.exists(warehousePath)) {
-                return Collections.emptyList();
-            }
-            
-            return Files.list(warehousePath)
-                    .filter(Files::isDirectory)
-                    .map(path -> path.getFileName().toString())
-                    .collect(Collectors.toList());
+          Files.createDirectories(warehousePath);
         } catch (IOException e) {
-            throw new CatalogException("Failed to list databases", e);
-        }
-    }
-
-    @Override
-    public CatalogDatabase getDatabase(String databaseName) throws DatabaseNotExistException, CatalogException {
-        if (!databaseExists(databaseName)) {
-            throw new DatabaseNotExistException(getName(), databaseName);
-        }
-        
-        return new CatalogDatabaseImpl(Collections.emptyMap(), "Lance Database: " + databaseName);
-    }
-
-    @Override
-    public boolean databaseExists(String databaseName) throws CatalogException {
-        if (isRemoteStorage) {
-            // Remote storage: check known databases or try listing tables to verify
-            if (knownDatabases.contains(databaseName)) {
-                return true;
-            }
-            // Try to confirm database exists by checking for tables
-            try {
-                String dbPath = getDatabasePath(databaseName);
-                // For remote storage, assume database always exists (actual table operations will verify)
-                return true;
-            } catch (Exception e) {
-                return false;
-            }
+          throw new CatalogException("Cannot create warehouse directory: " + warehouse, e);
         }
-        
-        Path dbPath = Paths.get(warehouse, databaseName);
-        return Files.exists(dbPath) && Files.isDirectory(dbPath);
-    }
+      }
 
-    @Override
-    public void createDatabase(String name, CatalogDatabase database, boolean ignoreIfExists)
-            throws DatabaseAlreadyExistException, CatalogException {
-        if (isRemoteStorage) {
-            // Remote storage: only record database name, actual directory created when creating table
-            if (knownDatabases.contains(name)) {
-                if (!ignoreIfExists) {
-                    throw new DatabaseAlreadyExistException(getName(), name);
-                }
-                return;
-            }
-            knownDatabases.add(name);
-            LOG.info("Registered remote database: {}", name);
-            return;
-        }
-        
-        if (databaseExists(name)) {
-            if (!ignoreIfExists) {
-                throw new DatabaseAlreadyExistException(getName(), name);
-            }
-            return;
-        }
-        
-        Path dbPath = Paths.get(warehouse, name);
+      // Ensure default database exists
+      Path defaultDbPath = warehousePath.resolve(getDefaultDatabase());
+      if (!Files.exists(defaultDbPath)) {
         try {
-            Files.createDirectories(dbPath);
-            LOG.info("Created database: {}", name);
+          Files.createDirectories(defaultDbPath);
         } catch (IOException e) {
-            throw new CatalogException("Failed to create database: " + name, e);
+          throw new CatalogException(
+              "Cannot create default database directory: " + defaultDbPath, e);
         }
+      }
     }
+  }
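For a local warehouse, open() above materializes both the warehouse directory and the default-database directory before any table exists. A minimal sketch of that bootstrap behaviour; the paths and the extra database name are assumptions, while the catalog calls are the ones defined in this class:

    import java.util.Collections;
    import org.apache.flink.connector.lance.table.LanceCatalog;
    import org.apache.flink.table.catalog.CatalogDatabaseImpl;

    public class CatalogBootstrapSketch {
      public static void main(String[] args) throws Exception {
        LanceCatalog catalog = new LanceCatalog("lance_catalog", "default", "/tmp/lance_warehouse");
        catalog.open(); // creates /tmp/lance_warehouse and /tmp/lance_warehouse/default
        catalog.createDatabase(
            "analytics", new CatalogDatabaseImpl(Collections.emptyMap(), null), true);
        System.out.println(catalog.listDatabases()); // e.g. [default, analytics]
        catalog.close();
      }
    }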
 
-    @Override
-    public void dropDatabase(String name, boolean ignoreIfNotExists, boolean cascade)
-            throws DatabaseNotExistException, DatabaseNotEmptyException, CatalogException {
-        if (isRemoteStorage) {
-            // Remote storage: remove database record
-            if (!knownDatabases.contains(name)) {
-                if (!ignoreIfNotExists) {
-                    throw new DatabaseNotExistException(getName(), name);
-                }
-                return;
-            }
-            
-            // Check if has tables
-            List<String> tables = listTables(name);
-            if (!tables.isEmpty() && !cascade) {
-                throw new DatabaseNotEmptyException(getName(), name);
-            }
-            
-            // If cascade, delete all tables
-            if (cascade) {
-                for (String table : tables) {
-                    try {
-                        dropTable(new ObjectPath(name, table), true);
-                    } catch (TableNotExistException e) {
-                        // Ignore
-                    }
-                }
-            }
-            
-            knownDatabases.remove(name);
-            LOG.info("Removed remote database record: {}", name);
-            return;
-        }
-        
-        if (!databaseExists(name)) {
-            if (!ignoreIfNotExists) {
-                throw new DatabaseNotExistException(getName(), name);
-            }
-            return;
-        }
-        
-        Path dbPath = Paths.get(warehouse, name);
-        try {
-            List<String> tables = listTables(name);
-            if (!tables.isEmpty() && !cascade) {
-                throw new DatabaseNotEmptyException(getName(), name);
-            }
-            
-            // Delete database directory
-            deleteDirectory(dbPath);
-            LOG.info("Deleted database: {}", name);
-        } catch (IOException e) {
-            throw new CatalogException("Failed to delete database: " + name, e);
-        }
-    }
-
-    @Override
-    public void alterDatabase(String name, CatalogDatabase newDatabase, boolean ignoreIfNotExists)
-            throws DatabaseNotExistException, CatalogException {
-        if (!databaseExists(name)) {
-            if (!ignoreIfNotExists) {
-                throw new DatabaseNotExistException(getName(), name);
-            }
-            return;
-        }
-        // Lance database does not support modifying properties
-        LOG.warn("Lance Catalog does not support modifying database properties");
-    }
+  @Override
+  public void close() throws CatalogException {
+    LOG.info("Closing Lance Catalog: {}", getName());
 
-    // ==================== Table Operations ====================
-
-    @Override
-    public List<String> listTables(String databaseName) throws DatabaseNotExistException, CatalogException {
-        if (!databaseExists(databaseName)) {
-            throw new DatabaseNotExistException(getName(), databaseName);
-        }
-        
-        if (isRemoteStorage) {
-            // Remote storage: return known table list
-            String prefix = databaseName + "/";
-            return knownTables.stream()
-                    .filter(t -> t.startsWith(prefix))
-                    .map(t -> t.substring(prefix.length()))
-                    .collect(Collectors.toList());
-        }
-        
-        try {
-            Path dbPath = Paths.get(warehouse, databaseName);
-            return Files.list(dbPath)
-                    .filter(Files::isDirectory)
-                    .filter(path -> Files.exists(path.resolve("_versions"))) // Lance dataset identifier
-                    .map(path -> path.getFileName().toString())
-                    .collect(Collectors.toList());
-        } catch (IOException e) {
-            throw new CatalogException("Failed to list tables", e);
-        }
+    if (allocator != null) {
+      try {
+        allocator.close();
+      } catch (Exception e) {
+        LOG.warn("Failed to close allocator", e);
+      }
+      allocator = null;
     }
 
-    @Override
-    public List<String> listViews(String databaseName) throws DatabaseNotExistException, CatalogException {
-        // Lance does not support views
-        return Collections.emptyList();
-    }
+    knownDatabases.clear();
+    knownTables.clear();
+  }
 
-    @Override
-    public CatalogBaseTable getTable(ObjectPath tablePath) throws TableNotExistException, CatalogException {
-        if (!tableExists(tablePath)) {
-            throw new TableNotExistException(getName(), tablePath);
-        }
-        
-        String datasetPath = getDatasetPath(tablePath);
-        
-        try {
-            // For remote storage, configure S3 credentials via environment variables
-            if (isRemoteStorage) {
-                configureStorageEnvironment();
-            }
-            Dataset dataset = Dataset.open(datasetPath, allocator);
-            
-            try {
-                // Infer Flink Schema from Lance Schema
-                org.apache.arrow.vector.types.pojo.Schema arrowSchema = dataset.getSchema();
-                RowType rowType = LanceTypeConverter.toFlinkRowType(arrowSchema);
-                
-                // Build CatalogTable
-                Schema.Builder schemaBuilder = Schema.newBuilder();
-                for (RowType.RowField field : rowType.getFields()) {
-                    DataType dataType = LanceTypeConverter.toDataType(field.getType());
-                    schemaBuilder.column(field.getName(), dataType);
-                }
-                
-                Map<String, String> options = new HashMap<>();
-                options.put("connector", LanceDynamicTableFactory.IDENTIFIER);
-                options.put("path", datasetPath);
-                
-                // If remote storage, add storage config to table options
-                if (isRemoteStorage) {
-                    options.putAll(getStorageOptionsForTable());
-                }
-                
-                return CatalogTable.of(
-                        schemaBuilder.build(),
-                        "Lance Table: " + tablePath.getFullName(),
-                        Collections.emptyList(),
-                        options
-                );
-            } finally {
-                dataset.close();
-            }
-        } catch (Exception e) {
-            throw new CatalogException("Failed to get table info: " + tablePath, e);
-        }
-    }
+  // ==================== Database Operations ====================
 
-    @Override
-    public boolean tableExists(ObjectPath tablePath) throws CatalogException {
-        if (!databaseExists(tablePath.getDatabaseName())) {
-            return false;
-        }
-        
-        String datasetPath = getDatasetPath(tablePath);
-        
-        if (isRemoteStorage) {
-            // Remote storage: check known tables or try opening dataset
-            String tableKey = tablePath.getDatabaseName() + "/" + tablePath.getObjectName();
-            if (knownTables.contains(tableKey)) {
-                return true;
-            }
-            
-            // Try to open dataset to verify existence
-            try {
-                configureStorageEnvironment();
-                Dataset dataset = Dataset.open(datasetPath, allocator);
-                dataset.close();
-                knownTables.add(tableKey);
-                return true;
-            } catch (Exception e) {
-                LOG.debug("Table does not exist or cannot be accessed: {}", datasetPath, e);
-                return false;
-            }
-        }
-        
-        Path path = Paths.get(datasetPath);
-        
-        // Check if valid Lance dataset
-        return Files.exists(path) && Files.isDirectory(path) && 
-               Files.exists(path.resolve("_versions"));
-    }
-
-    @Override
-    public void dropTable(ObjectPath tablePath, boolean ignoreIfNotExists)
-            throws TableNotExistException, CatalogException {
-        if (!tableExists(tablePath)) {
-            if (!ignoreIfNotExists) {
-                throw new TableNotExistException(getName(), tablePath);
-            }
-            return;
-        }
-        
-        String datasetPath = getDatasetPath(tablePath);
-        
-        if (isRemoteStorage) {
-            // Remote storage: Lance Java SDK does not directly support deleting remote datasets
-            // Only remove record here, actual deletion requires cloud storage API
-            String tableKey = tablePath.getDatabaseName() + "/" + tablePath.getObjectName();
-            knownTables.remove(tableKey);
-            LOG.warn("Remote storage mode, table record removed, but actual data needs manual deletion from storage: {}", datasetPath);
-            return;
-        }
-        
-        try {
-            deleteDirectory(Paths.get(datasetPath));
-            LOG.info("Deleted table: {}", tablePath);
-        } catch (IOException e) {
-            throw new CatalogException("Failed to delete table: " + tablePath, e);
-        }
+  @Override
+  public List<String> listDatabases() throws CatalogException {
+    if (isRemoteStorage) {
+      // Remote storage: return known database list
+      return new ArrayList<>(knownDatabases);
     }
 
-    @Override
-    public void renameTable(ObjectPath tablePath, String newTableName, boolean ignoreIfNotExists)
-            throws TableNotExistException, TableAlreadyExistException, CatalogException {
-        if (!tableExists(tablePath)) {
-            if (!ignoreIfNotExists) {
-                throw new TableNotExistException(getName(), tablePath);
-            }
-            return;
-        }
-        
-        ObjectPath newTablePath = new ObjectPath(tablePath.getDatabaseName(), newTableName);
-        if (tableExists(newTablePath)) {
-            throw new TableAlreadyExistException(getName(), newTablePath);
-        }
-        
-        if (isRemoteStorage) {
-            // Remote storage: does not support renaming
-            throw new CatalogException("Remote storage mode does not support renaming tables");
-        }
-        
-        String oldPath = getDatasetPath(tablePath);
-        String newPath = getDatasetPath(newTablePath);
-        
-        try {
-            Files.move(Paths.get(oldPath), Paths.get(newPath));
-            LOG.info("Renamed table: {} -> {}", tablePath, newTablePath);
-        } catch (IOException e) {
-            throw new CatalogException("Failed to rename table: " + tablePath, e);
-        }
-    }
-
-    @Override
-    public void createTable(ObjectPath tablePath, CatalogBaseTable table, boolean ignoreIfExists)
-            throws TableAlreadyExistException, DatabaseNotExistException, CatalogException {
-        if (!databaseExists(tablePath.getDatabaseName())) {
-            throw new DatabaseNotExistException(getName(), tablePath.getDatabaseName());
-        }
-        
-        if (tableExists(tablePath)) {
-            if (!ignoreIfExists) {
-                throw new TableAlreadyExistException(getName(), tablePath);
-            }
-            return;
-        }
-        
-        if (isRemoteStorage) {
-            // Remote storage: record table info, actual creation on write
-            String tableKey = tablePath.getDatabaseName() + "/" + tablePath.getObjectName();
-            knownTables.add(tableKey);
-        }
-        
-        // Actual table creation happens on first write
-        // Only record table metadata here
-        LOG.info("Registered table: {} (actual dataset will be created on write)", tablePath);
-    }
-
-    @Override
-    public void alterTable(ObjectPath tablePath, CatalogBaseTable newTable, boolean ignoreIfNotExists)
-            throws TableNotExistException, CatalogException {
-        if (!tableExists(tablePath)) {
-            if (!ignoreIfNotExists) {
-                throw new TableNotExistException(getName(), tablePath);
-            }
-            return;
-        }
-        
-        // Lance does not support modifying table structure
-        throw new CatalogException("Lance Catalog does not support altering table structure");
-    }
-
-    // ==================== Partition Operations (Lance does not support partitions) ====================
-
-    @Override
-    public List<CatalogPartitionSpec> listPartitions(ObjectPath tablePath)
-            throws TableNotExistException, TableNotPartitionedException, CatalogException {
-        return Collections.emptyList();
-    }
-
-    @Override
-    public List<CatalogPartitionSpec> listPartitions(ObjectPath tablePath, CatalogPartitionSpec partitionSpec)
-            throws TableNotExistException, TableNotPartitionedException, PartitionSpecInvalidException, CatalogException {
-        return Collections.emptyList();
-    }
-
-    @Override
-    public List<CatalogPartitionSpec> listPartitionsByFilter(ObjectPath tablePath, List<Expression> filters)
-            throws TableNotExistException, TableNotPartitionedException, CatalogException {
+    try {
+      Path warehousePath = Paths.get(warehouse);
+      if (!Files.exists(warehousePath)) {
         return Collections.emptyList();
-    }
-
-    @Override
-    public CatalogPartition getPartition(ObjectPath tablePath, CatalogPartitionSpec partitionSpec)
-            throws PartitionNotExistException, CatalogException {
-        throw new PartitionNotExistException(getName(), tablePath, partitionSpec);
-    }
-
-    @Override
-    public boolean partitionExists(ObjectPath tablePath, CatalogPartitionSpec partitionSpec)
-            throws CatalogException {
+      }
+
+      return Files.list(warehousePath)
+          .filter(Files::isDirectory)
+          .map(path -> path.getFileName().toString())
+          .collect(Collectors.toList());
+    } catch (IOException e) {
+      throw new CatalogException("Failed to list databases", e);
+    }
+  }
+
+  @Override
+  public CatalogDatabase getDatabase(String databaseName)
+      throws DatabaseNotExistException, CatalogException {
+    if (!databaseExists(databaseName)) {
+      throw new DatabaseNotExistException(getName(), databaseName);
+    }
+
+    return new CatalogDatabaseImpl(Collections.emptyMap(), "Lance Database: " + databaseName);
+  }
+
+  @Override
+  public boolean databaseExists(String databaseName) throws CatalogException {
+    if (isRemoteStorage) {
+      // Remote storage: check known databases or try listing tables to verify
+      if (knownDatabases.contains(databaseName)) {
+        return true;
+      }
+      // Try to confirm database exists by checking for tables
+      try {
+        String dbPath = getDatabasePath(databaseName);
+        // For remote storage, assume database always exists (actual table operations will verify)
+        return true;
+      } catch (Exception e) {
         return false;
-    }
-
-    @Override
-    public void createPartition(ObjectPath tablePath, CatalogPartitionSpec partitionSpec, CatalogPartition partition, boolean ignoreIfExists)
-            throws TableNotExistException, TableNotPartitionedException, PartitionSpecInvalidException, PartitionAlreadyExistsException, CatalogException {
-        throw new CatalogException("Lance Catalog does not support partition operations");
-    }
-
-    @Override
-    public void dropPartition(ObjectPath tablePath, CatalogPartitionSpec partitionSpec, boolean ignoreIfNotExists)
-            throws PartitionNotExistException, CatalogException {
-        throw new CatalogException("Lance Catalog does not support partition operations");
-    }
-
-    @Override
-    public void alterPartition(ObjectPath tablePath, CatalogPartitionSpec partitionSpec, CatalogPartition newPartition, boolean ignoreIfNotExists)
-            throws PartitionNotExistException, CatalogException {
-        throw new CatalogException("Lance Catalog does not support partition operations");
-    }
-
-    // ==================== Function Operations (Lance does not support UDFs) ====================
-
-    @Override
-    public List<String> listFunctions(String dbName) throws DatabaseNotExistException, CatalogException {
-        return Collections.emptyList();
-    }
+      }
+    }
+
+    Path dbPath = Paths.get(warehouse, databaseName);
+    return Files.exists(dbPath) && Files.isDirectory(dbPath);
+  }
+
+  @Override
+  public void createDatabase(String name, CatalogDatabase database, boolean ignoreIfExists)
+      throws DatabaseAlreadyExistException, CatalogException {
+    if (isRemoteStorage) {
+      // Remote storage: only record database name, actual directory created when creating table
+      if (knownDatabases.contains(name)) {
+        if (!ignoreIfExists) {
+          throw new DatabaseAlreadyExistException(getName(), name);
+        }
+        return;
+      }
+      knownDatabases.add(name);
+      LOG.info("Registered remote database: {}", name);
+      return;
+    }
+
+    if (databaseExists(name)) {
+      if (!ignoreIfExists) {
+        throw new DatabaseAlreadyExistException(getName(), name);
+      }
+      return;
+    }
+
+    Path dbPath = Paths.get(warehouse, name);
+    try {
+      Files.createDirectories(dbPath);
+      LOG.info("Created database: {}", name);
+    } catch (IOException e) {
+      throw new CatalogException("Failed to create database: " + name, e);
+    }
+  }
+
+  @Override
+  public void dropDatabase(String name, boolean ignoreIfNotExists, boolean cascade)
+      throws DatabaseNotExistException, DatabaseNotEmptyException, CatalogException {
+    if (isRemoteStorage) {
+      // Remote storage: remove database record
+      if (!knownDatabases.contains(name)) {
+        if (!ignoreIfNotExists) {
+          throw new DatabaseNotExistException(getName(), name);
+        }
+        return;
+      }
+
+      // Check whether the database still contains tables
+      List<String> tables = listTables(name);
+      if (!tables.isEmpty() && !cascade) {
+        throw new DatabaseNotEmptyException(getName(), name);
+      }
+
+      // If cascade, delete all tables
+      if (cascade) {
+        for (String table : tables) {
+          try {
+            dropTable(new ObjectPath(name, table), true);
+          } catch (TableNotExistException e) {
+            // Ignore
+          }
+        }
+      }
+
+      knownDatabases.remove(name);
+      LOG.info("Removed remote database record: {}", name);
+      return;
+    }
+
+    if (!databaseExists(name)) {
+      if (!ignoreIfNotExists) {
+        throw new DatabaseNotExistException(getName(), name);
+      }
+      return;
+    }
+
+    Path dbPath = Paths.get(warehouse, name);
+    try {
+      List<String> tables = listTables(name);
+      if (!tables.isEmpty() && !cascade) {
+        throw new DatabaseNotEmptyException(getName(), name);
+      }
+
+      // Delete database directory
+      deleteDirectory(dbPath);
+      LOG.info("Deleted database: {}", name);
+    } catch (IOException e) {
+      throw new CatalogException("Failed to delete database: " + name, e);
+    }
+  }
+
+  @Override
+  public void alterDatabase(String name, CatalogDatabase newDatabase, boolean ignoreIfNotExists)
+      throws DatabaseNotExistException, CatalogException {
+    if (!databaseExists(name)) {
+      if (!ignoreIfNotExists) {
+        throw new DatabaseNotExistException(getName(), name);
+      }
+      return;
+    }
+    // Lance database does not support modifying properties
+    LOG.warn("Lance Catalog does not support modifying database properties");
+  }
+
+  // ==================== Table Operations ====================
+
+  @Override
+  public List<String> listTables(String databaseName)
+      throws DatabaseNotExistException, CatalogException {
+    if (!databaseExists(databaseName)) {
+      throw new DatabaseNotExistException(getName(), databaseName);
+    }
+
+    if (isRemoteStorage) {
+      // Remote storage: return known table list
+      String prefix = databaseName + "/";
+      return knownTables.stream()
+          .filter(t -> t.startsWith(prefix))
+          .map(t -> t.substring(prefix.length()))
+          .collect(Collectors.toList());
+    }
+
+    try {
+      Path dbPath = Paths.get(warehouse, databaseName);
+      return Files.list(dbPath)
+          .filter(Files::isDirectory)
+          .filter(path -> Files.exists(path.resolve("_versions"))) // Lance dataset identifier
+          .map(path -> path.getFileName().toString())
+          .collect(Collectors.toList());
+    } catch (IOException e) {
+      throw new CatalogException("Failed to list tables", e);
+    }
+  }
+
+  @Override
+  public List<String> listViews(String databaseName)
+      throws DatabaseNotExistException, CatalogException {
+    // Lance does not support views
+    return Collections.emptyList();
+  }
+
+  @Override
+  public CatalogBaseTable getTable(ObjectPath tablePath)
+      throws TableNotExistException, CatalogException {
+    if (!tableExists(tablePath)) {
+      throw new TableNotExistException(getName(), tablePath);
+    }
+
+    String datasetPath = getDatasetPath(tablePath);
+
+    try {
+      // For remote storage, configure S3 credentials via environment variables
+      if (isRemoteStorage) {
+        configureStorageEnvironment();
+      }
+      Dataset dataset = Dataset.open(datasetPath, allocator);
+
+      try {
+        // Infer Flink Schema from Lance Schema
+        org.apache.arrow.vector.types.pojo.Schema arrowSchema = dataset.getSchema();
+        RowType rowType = LanceTypeConverter.toFlinkRowType(arrowSchema);
+
+        // Build CatalogTable
+        Schema.Builder schemaBuilder = Schema.newBuilder();
+        for (RowType.RowField field : rowType.getFields()) {
+          DataType dataType = LanceTypeConverter.toDataType(field.getType());
+          schemaBuilder.column(field.getName(), dataType);
+        }
 
-    @Override
-    public CatalogFunction getFunction(ObjectPath functionPath) throws FunctionNotExistException, CatalogException {
-        throw new FunctionNotExistException(getName(), functionPath);
-    }
+        Map<String, String> options = new HashMap<>();
+        options.put("connector", LanceDynamicTableFactory.IDENTIFIER);
+        options.put("path", datasetPath);
 
-    @Override
-    public boolean functionExists(ObjectPath functionPath) throws CatalogException {
+        // If remote storage, add storage config to table options
+        if (isRemoteStorage) {
+          options.putAll(getStorageOptionsForTable());
+        }
+
+        return CatalogTable.of(
+            schemaBuilder.build(),
+            "Lance Table: " + tablePath.getFullName(),
+            Collections.emptyList(),
+            options);
+      } finally {
+        dataset.close();
+      }
+    } catch (Exception e) {
+      throw new CatalogException("Failed to get table info: " + tablePath, e);
+    }
+  }
+
+  @Override
+  public boolean tableExists(ObjectPath tablePath) throws CatalogException {
+    if (!databaseExists(tablePath.getDatabaseName())) {
+      return false;
+    }
+
+    String datasetPath = getDatasetPath(tablePath);
+
+    if (isRemoteStorage) {
+      // Remote storage: check known tables or try opening dataset
+      String tableKey = tablePath.getDatabaseName() + "/" + tablePath.getObjectName();
+      if (knownTables.contains(tableKey)) {
+        return true;
+      }
+
+      // Try to open dataset to verify existence
+      try {
+        configureStorageEnvironment();
+        Dataset dataset = Dataset.open(datasetPath, allocator);
+        dataset.close();
+        knownTables.add(tableKey);
+        return true;
+      } catch (Exception e) {
+        LOG.debug("Table does not exist or cannot be accessed: {}", datasetPath, e);
         return false;
+      }
     }
 
-    @Override
-    public void createFunction(ObjectPath functionPath, CatalogFunction function, boolean ignoreIfExists)
-            throws FunctionAlreadyExistException, DatabaseNotExistException, CatalogException {
-        throw new CatalogException("Lance Catalog does not support user-defined functions");
-    }
-
-    @Override
-    public void alterFunction(ObjectPath functionPath, CatalogFunction newFunction, boolean ignoreIfNotExists)
-            throws FunctionNotExistException, CatalogException {
-        throw new CatalogException("Lance Catalog does not support user-defined functions");
-    }
-
-    @Override
-    public void dropFunction(ObjectPath functionPath, boolean ignoreIfNotExists)
-            throws FunctionNotExistException, CatalogException {
-        throw new CatalogException("Lance Catalog does not support user-defined functions");
-    }
-
-    // ==================== Statistics Operations ====================
-
-    @Override
-    public CatalogTableStatistics getTableStatistics(ObjectPath tablePath)
-            throws TableNotExistException, CatalogException {
-        return CatalogTableStatistics.UNKNOWN;
-    }
-
-    @Override
-    public CatalogColumnStatistics getTableColumnStatistics(ObjectPath tablePath)
-            throws TableNotExistException, CatalogException {
-        return CatalogColumnStatistics.UNKNOWN;
-    }
-
-    @Override
-    public CatalogTableStatistics getPartitionStatistics(ObjectPath tablePath, CatalogPartitionSpec partitionSpec)
-            throws PartitionNotExistException, CatalogException {
-        return CatalogTableStatistics.UNKNOWN;
-    }
-
-    @Override
-    public CatalogColumnStatistics getPartitionColumnStatistics(ObjectPath tablePath, CatalogPartitionSpec partitionSpec)
-            throws PartitionNotExistException, CatalogException {
-        return CatalogColumnStatistics.UNKNOWN;
-    }
+    Path path = Paths.get(datasetPath);
 
-    @Override
-    public void alterTableStatistics(ObjectPath tablePath, CatalogTableStatistics tableStatistics, boolean ignoreIfNotExists)
-            throws TableNotExistException, CatalogException {
-        // Not supported
-    }
-
-    @Override
-    public void alterTableColumnStatistics(ObjectPath tablePath, CatalogColumnStatistics columnStatistics, boolean ignoreIfNotExists)
-            throws TableNotExistException, CatalogException {
-        // Not supported
-    }
-
-    @Override
-    public void alterPartitionStatistics(ObjectPath tablePath, CatalogPartitionSpec partitionSpec, CatalogTableStatistics partitionStatistics, boolean ignoreIfNotExists)
-            throws PartitionNotExistException, CatalogException {
-        // Not supported
-    }
+    // Check if valid Lance dataset
+    return Files.exists(path) && Files.isDirectory(path) && Files.exists(path.resolve("_versions"));
+  }
 
-    @Override
-    public void alterPartitionColumnStatistics(ObjectPath tablePath, CatalogPartitionSpec partitionSpec, CatalogColumnStatistics columnStatistics, boolean ignoreIfNotExists)
-            throws PartitionNotExistException, CatalogException {
-        // Not supported
-    }
-
-    // ==================== Utility Methods ====================
-
-    /**
-     * Configure storage environment variables (for S3 and other remote storage)
-     * 
-     * <p>Lance configures S3 credentials via environment variables:
-     * <ul>
-     *   <li>AWS_ACCESS_KEY_ID - AWS access key ID</li>
-     *   <li>AWS_SECRET_ACCESS_KEY - AWS secret access key</li>
-     *   <li>AWS_DEFAULT_REGION - AWS region</li>
-     *   <li>AWS_ENDPOINT - Custom endpoint URL (for S3-compatible storage)</li>
-     * </ul>
-     */
-    private void configureStorageEnvironment() {
-        if (!isRemoteStorage || storageOptions.isEmpty()) {
-            return;
-        }
-
-        // Set environment variables for Lance SDK object_store configuration
-        // Note: Since Java cannot directly modify environment variables, system properties are used as fallback
-        // Lance's Rust backend will read these environment variables
-
-        if (storageOptions.containsKey("aws_access_key_id")) {
-            System.setProperty("AWS_ACCESS_KEY_ID", storageOptions.get("aws_access_key_id"));
-        }
-        if (storageOptions.containsKey("aws_secret_access_key")) {
-            System.setProperty("AWS_SECRET_ACCESS_KEY", storageOptions.get("aws_secret_access_key"));
-        }
-        if (storageOptions.containsKey("aws_region")) {
-            System.setProperty("AWS_DEFAULT_REGION", storageOptions.get("aws_region"));
-        }
-        if (storageOptions.containsKey("aws_endpoint")) {
-            System.setProperty("AWS_ENDPOINT", storageOptions.get("aws_endpoint"));
-        }
-        if (storageOptions.containsKey("aws_virtual_hosted_style_request")) {
-            System.setProperty("AWS_VIRTUAL_HOSTED_STYLE_REQUEST",
-                    storageOptions.get("aws_virtual_hosted_style_request"));
-        }
-        if (storageOptions.containsKey("allow_http")) {
-            System.setProperty("AWS_ALLOW_HTTP", storageOptions.get("allow_http"));
-        }
-
-        LOG.debug("Configured remote storage environment variables");
+  @Override
+  public void dropTable(ObjectPath tablePath, boolean ignoreIfNotExists)
+      throws TableNotExistException, CatalogException {
+    if (!tableExists(tablePath)) {
+      if (!ignoreIfNotExists) {
+        throw new TableNotExistException(getName(), tablePath);
+      }
+      return;
     }

-    /**
-     * Get database path
-     */
-    private String getDatabasePath(String databaseName) {
-        if (isRemoteStorage) {
-            return warehouse + "/" + databaseName;
-        }
-        return Paths.get(warehouse, databaseName).toString();
-    }
+    String datasetPath = getDatasetPath(tablePath);

-    /**
-     * Get dataset path
-     */
-    private String getDatasetPath(ObjectPath tablePath) {
-        if (isRemoteStorage) {
-            return warehouse + "/" + tablePath.getDatabaseName() + "/" + tablePath.getObjectName();
-        }
-        return Paths.get(warehouse, tablePath.getDatabaseName(), tablePath.getObjectName()).toString();
+    if (isRemoteStorage) {
+      // Remote storage: Lance Java SDK does not directly support deleting remote datasets
+      // Only remove record here, actual deletion requires cloud storage API
+      String tableKey = tablePath.getDatabaseName() + "/" + tablePath.getObjectName();
+      knownTables.remove(tableKey);
+      LOG.warn(
+          "Remote storage mode, table record removed, but actual data needs"
+              + " manual deletion from storage: {}",
+          datasetPath);
+      return;
     }

-    /**
-     * Get storage options for table configuration
-     */
-    private Map<String, String> getStorageOptionsForTable() {
-        Map<String, String> options = new HashMap<>();
-
-        // Convert storage options to table config format
-        if (storageOptions.containsKey("aws_access_key_id")) {
-            options.put("s3-access-key", storageOptions.get("aws_access_key_id"));
-        }
-        if (storageOptions.containsKey("aws_secret_access_key")) {
-            options.put("s3-secret-key", storageOptions.get("aws_secret_access_key"));
-        }
-        if (storageOptions.containsKey("aws_region")) {
-            options.put("s3-region", storageOptions.get("aws_region"));
-        }
-        if (storageOptions.containsKey("aws_endpoint")) {
-            options.put("s3-endpoint", storageOptions.get("aws_endpoint"));
-        }
-
-        return options;
+    try {
+      deleteDirectory(Paths.get(datasetPath));
+      LOG.info("Deleted table: {}", tablePath);
+    } catch (IOException e) {
+      throw new CatalogException("Failed to delete table: " + tablePath, e);
     }
-
-    /**
-     * Recursively delete directory
-     */
-    private void deleteDirectory(Path path) throws IOException {
-        if (Files.isDirectory(path)) {
-            Files.list(path).forEach(child -> {
+  }
+
+  @Override
+  public void renameTable(ObjectPath tablePath, String newTableName, boolean ignoreIfNotExists)
+      throws TableNotExistException, TableAlreadyExistException, CatalogException {
+    if (!tableExists(tablePath)) {
+      if (!ignoreIfNotExists) {
+        throw new TableNotExistException(getName(), tablePath);
+      }
+      return;
+    }
+
+    ObjectPath newTablePath = new ObjectPath(tablePath.getDatabaseName(), newTableName);
+    if (tableExists(newTablePath)) {
+      throw new TableAlreadyExistException(getName(), newTablePath);
+    }
+
+    if (isRemoteStorage) {
+      // Remote storage: does not support renaming
+      throw new CatalogException("Remote storage mode does not support renaming tables");
+    }
+
+    String oldPath = getDatasetPath(tablePath);
+    String newPath = getDatasetPath(newTablePath);
+
+    try {
+      Files.move(Paths.get(oldPath), Paths.get(newPath));
+      LOG.info("Renamed table: {} -> {}", tablePath, newTablePath);
+    } catch (IOException e) {
+      throw new CatalogException("Failed to rename table: " + tablePath, e);
+    }
+  }
+
+  @Override
+  public void createTable(ObjectPath tablePath, CatalogBaseTable table, boolean ignoreIfExists)
+      throws TableAlreadyExistException, DatabaseNotExistException, CatalogException {
+    if (!databaseExists(tablePath.getDatabaseName())) {
+      throw new DatabaseNotExistException(getName(), tablePath.getDatabaseName());
+    }
+
+    if (tableExists(tablePath)) {
+      if (!ignoreIfExists) {
+        throw new TableAlreadyExistException(getName(), tablePath);
+      }
+      return;
+    }
+
+    if (isRemoteStorage) {
+      // Remote storage: record table info, actual creation on write
+      String tableKey = tablePath.getDatabaseName() + "/" + tablePath.getObjectName();
+      knownTables.add(tableKey);
+    }
+
+    // Actual table creation happens on first write
+    // Only record table metadata here
+    LOG.info("Registered table: {} (actual dataset will be created on write)", tablePath);
+  }
+
+  @Override
+  public void alterTable(ObjectPath tablePath, CatalogBaseTable newTable, boolean ignoreIfNotExists)
+      throws TableNotExistException, CatalogException {
+    if (!tableExists(tablePath)) {
+      if (!ignoreIfNotExists) {
+        throw new TableNotExistException(getName(), tablePath);
+      }
+      return;
+    }
+
+    // Lance does not support modifying table structure
+    throw new CatalogException("Lance Catalog does not support altering table structure");
+  }
+
+  // ==================== Partition Operations (Lance does not support partitions)
+  // ====================
+
+  @Override
+  public List<CatalogPartitionSpec> listPartitions(ObjectPath tablePath)
+      throws TableNotExistException, TableNotPartitionedException, CatalogException {
+    return Collections.emptyList();
+  }
+
+  @Override
+  public List<CatalogPartitionSpec> listPartitions(
+      ObjectPath tablePath, CatalogPartitionSpec partitionSpec)
+      throws TableNotExistException,
+          TableNotPartitionedException,
+          PartitionSpecInvalidException,
+          CatalogException {
+    return Collections.emptyList();
+  }
+
+  @Override
+  public List<CatalogPartitionSpec> listPartitionsByFilter(
+      ObjectPath tablePath, List<Expression> filters)
+      throws TableNotExistException, TableNotPartitionedException, CatalogException {
+    return Collections.emptyList();
+  }
+
+  @Override
+  public CatalogPartition getPartition(ObjectPath tablePath, CatalogPartitionSpec partitionSpec)
+      throws PartitionNotExistException, CatalogException {
+    throw new PartitionNotExistException(getName(), tablePath, partitionSpec);
+  }
+
+  @Override
+  public boolean partitionExists(ObjectPath tablePath, CatalogPartitionSpec partitionSpec)
+      throws CatalogException {
+    return false;
+  }
+
+  @Override
+  public void createPartition(
+      ObjectPath tablePath,
+      CatalogPartitionSpec partitionSpec,
+      CatalogPartition partition,
+      boolean ignoreIfExists)
+      throws TableNotExistException,
+          TableNotPartitionedException,
+          PartitionSpecInvalidException,
+          PartitionAlreadyExistsException,
+          CatalogException {
+    throw new CatalogException("Lance Catalog does not support partition operations");
+  }
+
+  @Override
+  public void dropPartition(
+      ObjectPath tablePath, CatalogPartitionSpec partitionSpec, boolean ignoreIfNotExists)
+      throws PartitionNotExistException, CatalogException {
+    throw new CatalogException("Lance Catalog does not support partition operations");
+  }
+
+  @Override
+  public void alterPartition(
+      ObjectPath tablePath,
+      CatalogPartitionSpec partitionSpec,
+      CatalogPartition newPartition,
+      boolean ignoreIfNotExists)
+      throws PartitionNotExistException, CatalogException {
+    throw new CatalogException("Lance Catalog does not support partition operations");
+  }
+
+  // ==================== Function Operations (Lance does not support UDFs) ====================
+
+  @Override
+  public List<String> listFunctions(String dbName)
+      throws DatabaseNotExistException, CatalogException {
+    return Collections.emptyList();
+  }
+
+  @Override
+  public CatalogFunction getFunction(ObjectPath functionPath)
+      throws FunctionNotExistException, CatalogException {
+    throw new FunctionNotExistException(getName(), functionPath);
+  }
+
+  @Override
+  public boolean functionExists(ObjectPath functionPath) throws CatalogException {
+    return false;
+  }
+
+  @Override
+  public void createFunction(
+      ObjectPath functionPath, CatalogFunction function, boolean ignoreIfExists)
+      throws FunctionAlreadyExistException, DatabaseNotExistException, CatalogException {
+    throw new CatalogException("Lance Catalog does not support user-defined functions");
+  }
+
+  @Override
+  public void alterFunction(
+      ObjectPath functionPath, CatalogFunction newFunction, boolean ignoreIfNotExists)
+      throws FunctionNotExistException, CatalogException {
+    throw new CatalogException("Lance Catalog does not support user-defined functions");
+  }
+
+  @Override
+  public void dropFunction(ObjectPath functionPath, boolean ignoreIfNotExists)
+      throws FunctionNotExistException, CatalogException {
+    throw new CatalogException("Lance Catalog does not support user-defined functions");
+  }
+
+  // ==================== Statistics Operations ====================
+
+  @Override
+  public CatalogTableStatistics getTableStatistics(ObjectPath tablePath)
+      throws TableNotExistException, CatalogException {
+    return CatalogTableStatistics.UNKNOWN;
+  }
+
+  @Override
+  public CatalogColumnStatistics getTableColumnStatistics(ObjectPath tablePath)
+      throws TableNotExistException, CatalogException {
+    return CatalogColumnStatistics.UNKNOWN;
+  }
+
+  @Override
+  public CatalogTableStatistics getPartitionStatistics(
+      ObjectPath tablePath, CatalogPartitionSpec partitionSpec)
+      throws PartitionNotExistException, CatalogException {
+    return CatalogTableStatistics.UNKNOWN;
+  }
+
+  @Override
+  public CatalogColumnStatistics getPartitionColumnStatistics(
+      ObjectPath tablePath, CatalogPartitionSpec partitionSpec)
+      throws PartitionNotExistException, CatalogException {
+    return CatalogColumnStatistics.UNKNOWN;
+  }
+
+  @Override
+  public void alterTableStatistics(
+      ObjectPath tablePath, CatalogTableStatistics tableStatistics, boolean ignoreIfNotExists)
+      throws TableNotExistException, CatalogException {
+    // Not supported
+  }
+
+  @Override
+  public void alterTableColumnStatistics(
+      ObjectPath tablePath, CatalogColumnStatistics columnStatistics, boolean ignoreIfNotExists)
+      throws TableNotExistException, CatalogException {
+    // Not supported
+  }
+
+  @Override
+  public void alterPartitionStatistics(
+      ObjectPath tablePath,
+      CatalogPartitionSpec partitionSpec,
+      CatalogTableStatistics partitionStatistics,
+      boolean ignoreIfNotExists)
+      throws PartitionNotExistException, CatalogException {
+    // Not supported
+  }
+
+  @Override
+  public void alterPartitionColumnStatistics(
+      ObjectPath tablePath,
+      CatalogPartitionSpec partitionSpec,
+      CatalogColumnStatistics columnStatistics,
+      boolean ignoreIfNotExists)
+      throws PartitionNotExistException, CatalogException {
+    // Not supported
+  }
+
+  // ==================== Utility Methods ====================
+
+  /**
+   * Configure storage environment variables (for S3 and other remote storage)
+   *
+   * <p>Lance configures S3 credentials via environment variables:
+   *
+   * <ul>
+   *   <li>AWS_ACCESS_KEY_ID - AWS access key ID
+   *   <li>AWS_SECRET_ACCESS_KEY - AWS secret access key
+   *   <li>AWS_DEFAULT_REGION - AWS region
+   *   <li>AWS_ENDPOINT - Custom endpoint URL (for S3-compatible storage)
+   * </ul>
+   */
+  private void configureStorageEnvironment() {
+    if (!isRemoteStorage || storageOptions.isEmpty()) {
+      return;
+    }
+
+    // Set environment variables for Lance SDK object_store configuration
+    // Note: Since Java cannot directly modify environment variables, system properties are used as
+    // fallback
+    // Lance's Rust backend will read these environment variables
+
+    if (storageOptions.containsKey("aws_access_key_id")) {
+      System.setProperty("AWS_ACCESS_KEY_ID", storageOptions.get("aws_access_key_id"));
+    }
+    if (storageOptions.containsKey("aws_secret_access_key")) {
+      System.setProperty("AWS_SECRET_ACCESS_KEY", storageOptions.get("aws_secret_access_key"));
+    }
+    if (storageOptions.containsKey("aws_region")) {
+      System.setProperty("AWS_DEFAULT_REGION", storageOptions.get("aws_region"));
+    }
+    if (storageOptions.containsKey("aws_endpoint")) {
+      System.setProperty("AWS_ENDPOINT", storageOptions.get("aws_endpoint"));
+    }
+    if (storageOptions.containsKey("aws_virtual_hosted_style_request")) {
+      System.setProperty(
+          "AWS_VIRTUAL_HOSTED_STYLE_REQUEST",
+          storageOptions.get("aws_virtual_hosted_style_request"));
+    }
+    if (storageOptions.containsKey("allow_http")) {
+      System.setProperty("AWS_ALLOW_HTTP", storageOptions.get("allow_http"));
+    }
+
+    LOG.debug("Configured remote storage environment variables");
+  }
+
+  /** Get database path */
+  private String getDatabasePath(String databaseName) {
+    if (isRemoteStorage) {
+      return warehouse + "/" + databaseName;
+    }
+    return Paths.get(warehouse, databaseName).toString();
+  }
+
+  /** Get dataset path */
+  private String getDatasetPath(ObjectPath tablePath) {
+    if (isRemoteStorage) {
+      return warehouse + "/" + tablePath.getDatabaseName() + "/" + tablePath.getObjectName();
+    }
+    return Paths.get(warehouse, tablePath.getDatabaseName(), tablePath.getObjectName()).toString();
+  }
+
+  /** Get storage options for table configuration */
+  private Map<String, String> getStorageOptionsForTable() {
+    Map<String, String> options = new HashMap<>();
+
+    // Convert storage options to table config format
+    if (storageOptions.containsKey("aws_access_key_id")) {
+      options.put("s3-access-key", storageOptions.get("aws_access_key_id"));
+    }
+    if (storageOptions.containsKey("aws_secret_access_key")) {
+      options.put("s3-secret-key", storageOptions.get("aws_secret_access_key"));
+    }
+    if (storageOptions.containsKey("aws_region")) {
+      options.put("s3-region", storageOptions.get("aws_region"));
+    }
+    if (storageOptions.containsKey("aws_endpoint")) {
+      options.put("s3-endpoint", storageOptions.get("aws_endpoint"));
+    }
+
+    return options;
+  }
+
+  /** Recursively delete directory */
+  private void deleteDirectory(Path path) throws IOException {
+    if (Files.isDirectory(path)) {
+      Files.list(path)
+          .forEach(
+              child -> {
                 try {
-                    deleteDirectory(child);
+                  deleteDirectory(child);
                 } catch (IOException e) {
-                    LOG.warn("Failed to delete file: {}", child, e);
+                  LOG.warn("Failed to delete file: {}", child, e);
                 }
-            });
-        }
-        Files.deleteIfExists(path);
+              });
     }
+    Files.deleteIfExists(path);
+  }

-    /**
-     * Get warehouse path
-     */
-    public String getWarehouse() {
-        return warehouse;
-    }
+  /** Get warehouse path */
+  public String getWarehouse() {
+    return warehouse;
+  }

-    /**
-     * Get storage configuration options
-     */
-    public Map<String, String> getStorageOptions() {
-        return Collections.unmodifiableMap(storageOptions);
-    }
+  /** Get storage configuration options */
+  public Map<String, String> getStorageOptions() {
+    return Collections.unmodifiableMap(storageOptions);
+  }

-    /**
-     * Whether is remote storage
-     */
-    public boolean isRemoteStorage() {
-        return isRemoteStorage;
-    }
+  /** Whether is remote storage */
+  public boolean isRemoteStorage() {
+    return isRemoteStorage;
+  }
 }
diff --git a/src/main/java/org/apache/flink/connector/lance/table/LanceCatalogFactory.java b/src/main/java/org/apache/flink/connector/lance/table/LanceCatalogFactory.java
index 76c8436..64d8ae9 100644
--- a/src/main/java/org/apache/flink/connector/lance/table/LanceCatalogFactory.java
+++ b/src/main/java/org/apache/flink/connector/lance/table/LanceCatalogFactory.java
@@ -1,11 +1,7 @@
 /*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
@@ -15,7 +11,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.flink.connector.lance.table;
 
 import org.apache.flink.configuration.ConfigOption;
@@ -31,10 +26,11 @@
 
 /**
  * Lance Catalog factory.
- * 
+ *
  * <p>Used to create LanceCatalog via SQL DDL.
- * 
+ *
  * <p>Usage example (local path):
+ *
  * <pre>{@code
  * CREATE CATALOG lance_catalog WITH (
  *     'type' = 'lance',
@@ -42,8 +38,9 @@
  *     'default-database' = 'default'
  * );
  * }</pre>
- * 
+ *
  * <p>Usage example (S3 path):
+ *
  * <pre>{@code
  * CREATE CATALOG lance_s3_catalog WITH (
  *     'type' = 'lance',
@@ -58,123 +55,124 @@
  */
 public class LanceCatalogFactory implements CatalogFactory {
 
-    public static final String IDENTIFIER = "lance";
-
-    public static final ConfigOption<String> WAREHOUSE = ConfigOptions
-            .key("warehouse")
-            .stringType()
-            .noDefaultValue()
-            .withDescription("Lance data warehouse path, supports local path or S3 path (e.g., s3://bucket/path)");
-
-    public static final ConfigOption<String> DEFAULT_DATABASE = ConfigOptions
-            .key("default-database")
-            .stringType()
-            .defaultValue(LanceCatalog.DEFAULT_DATABASE)
-            .withDescription("Default database name");
-
-    // ==================== S3 Configuration Options ====================
-    
-    public static final ConfigOption<String> S3_ACCESS_KEY = ConfigOptions
-            .key("s3-access-key")
-            .stringType()
-            .noDefaultValue()
-            .withDescription("S3 Access Key ID");
-
-    public static final ConfigOption<String> S3_SECRET_KEY = ConfigOptions
-            .key("s3-secret-key")
-            .stringType()
-            .noDefaultValue()
-            .withDescription("S3 Secret Access Key");
-
-    public static final ConfigOption<String> S3_REGION = ConfigOptions
-            .key("s3-region")
-            .stringType()
-            .noDefaultValue()
-            .withDescription("S3 Region (e.g., us-east-1)");
-
-    public static final ConfigOption<String> S3_ENDPOINT = ConfigOptions
-            .key("s3-endpoint")
-            .stringType()
-            .noDefaultValue()
-            .withDescription("S3 Endpoint URL (for S3-compatible object storage like MinIO)");
-
-    public static final ConfigOption<Boolean> S3_VIRTUAL_HOSTED_STYLE = ConfigOptions
-            .key("s3-virtual-hosted-style")
-            .booleanType()
-            .defaultValue(true)
-            .withDescription("Whether to use virtual hosted style URL (default true)");
-
-    public static final ConfigOption<Boolean> S3_ALLOW_HTTP = ConfigOptions
-            .key("s3-allow-http")
-            .booleanType()
-            .defaultValue(false)
-            .withDescription("Whether to allow HTTP connections (default false, HTTPS only)");
-
-    @Override
-    public String factoryIdentifier() {
-        return IDENTIFIER;
+  public static final String IDENTIFIER = "lance";
+
+  public static final ConfigOption<String> WAREHOUSE =
+      ConfigOptions.key("warehouse")
+          .stringType()
+          .noDefaultValue()
+          .withDescription(
+              "Lance data warehouse path, supports local path or S3 path (e.g., s3://bucket/path)");
+
+  public static final ConfigOption<String> DEFAULT_DATABASE =
+      ConfigOptions.key("default-database")
+          .stringType()
+          .defaultValue(LanceCatalog.DEFAULT_DATABASE)
+          .withDescription("Default database name");
+
+  // ==================== S3 Configuration Options ====================
+
+  public static final ConfigOption<String> S3_ACCESS_KEY =
+      ConfigOptions.key("s3-access-key")
+          .stringType()
+          .noDefaultValue()
+          .withDescription("S3 Access Key ID");
+
+  public static final ConfigOption<String> S3_SECRET_KEY =
+      ConfigOptions.key("s3-secret-key")
+          .stringType()
+          .noDefaultValue()
+          .withDescription("S3 Secret Access Key");
+
+  public static final ConfigOption<String> S3_REGION =
+      ConfigOptions.key("s3-region")
+          .stringType()
+          .noDefaultValue()
+          .withDescription("S3 Region (e.g., us-east-1)");
+
+  public static final ConfigOption<String> S3_ENDPOINT =
+      ConfigOptions.key("s3-endpoint")
+          .stringType()
+          .noDefaultValue()
+          .withDescription("S3 Endpoint URL (for S3-compatible object storage like MinIO)");
+
+  public static final ConfigOption<Boolean> S3_VIRTUAL_HOSTED_STYLE =
+      ConfigOptions.key("s3-virtual-hosted-style")
+          .booleanType()
+          .defaultValue(true)
+          .withDescription("Whether to use virtual hosted style URL (default true)");
+
+  public static final ConfigOption<Boolean> S3_ALLOW_HTTP =
+      ConfigOptions.key("s3-allow-http")
+          .booleanType()
+          .defaultValue(false)
+          .withDescription("Whether to allow HTTP connections (default false, HTTPS only)");
+
+  @Override
+  public String factoryIdentifier() {
+    return IDENTIFIER;
+  }
+
+  @Override
+  public Set<ConfigOption<?>> requiredOptions() {
+    Set<ConfigOption<?>> options = new HashSet<>();
+    options.add(WAREHOUSE);
+    return options;
+  }
+
+  @Override
+  public Set<ConfigOption<?>> optionalOptions() {
+    Set<ConfigOption<?>> options = new HashSet<>();
+    options.add(DEFAULT_DATABASE);
+    // S3 related options
+    options.add(S3_ACCESS_KEY);
+    options.add(S3_SECRET_KEY);
+    options.add(S3_REGION);
+    options.add(S3_ENDPOINT);
+    options.add(S3_VIRTUAL_HOSTED_STYLE);
+    options.add(S3_ALLOW_HTTP);
+    return options;
+  }
+
+  @Override
+  public Catalog createCatalog(Context context) {
+    FactoryUtil.CatalogFactoryHelper helper = FactoryUtil.createCatalogFactoryHelper(this, context);
+    helper.validate();
+
+    String catalogName = context.getName();
+    String warehouse = helper.getOptions().get(WAREHOUSE);
+    String defaultDatabase = helper.getOptions().get(DEFAULT_DATABASE);
+
+    // Collect storage configuration
+    Map<String, String> storageOptions = new HashMap<>();
+
+    // S3 configuration
+    String accessKey = helper.getOptions().get(S3_ACCESS_KEY);
+    String secretKey = helper.getOptions().get(S3_SECRET_KEY);
+    String region = helper.getOptions().get(S3_REGION);
+    String endpoint = helper.getOptions().get(S3_ENDPOINT);
+    Boolean virtualHostedStyle = helper.getOptions().get(S3_VIRTUAL_HOSTED_STYLE);
+    Boolean allowHttp = helper.getOptions().get(S3_ALLOW_HTTP);
+
+    if (accessKey != null) {
+      storageOptions.put("aws_access_key_id", accessKey);
     }
-
-    @Override
-    public Set<ConfigOption<?>> requiredOptions() {
-        Set<ConfigOption<?>> options = new HashSet<>();
-        options.add(WAREHOUSE);
-        return options;
+    if (secretKey != null) {
+      storageOptions.put("aws_secret_access_key", secretKey);
     }
-
-    @Override
-    public Set<ConfigOption<?>> optionalOptions() {
-        Set<ConfigOption<?>> options = new HashSet<>();
-        options.add(DEFAULT_DATABASE);
-        // S3 related options
-        options.add(S3_ACCESS_KEY);
-        options.add(S3_SECRET_KEY);
-        options.add(S3_REGION);
-        options.add(S3_ENDPOINT);
-        options.add(S3_VIRTUAL_HOSTED_STYLE);
-        options.add(S3_ALLOW_HTTP);
-        return options;
+    if (region != null) {
+      storageOptions.put("aws_region", region);
     }
-
-    @Override
-    public Catalog createCatalog(Context context) {
-        FactoryUtil.CatalogFactoryHelper helper = FactoryUtil.createCatalogFactoryHelper(this, context);
-        helper.validate();
-
-        String catalogName = context.getName();
-        String warehouse = helper.getOptions().get(WAREHOUSE);
-        String defaultDatabase = helper.getOptions().get(DEFAULT_DATABASE);
-
-        // Collect storage configuration
-        Map<String, String> storageOptions = new HashMap<>();
-        
-        // S3 configuration
-        String accessKey = helper.getOptions().get(S3_ACCESS_KEY);
-        String secretKey = helper.getOptions().get(S3_SECRET_KEY);
-        String region = helper.getOptions().get(S3_REGION);
-        String endpoint = helper.getOptions().get(S3_ENDPOINT);
-        Boolean virtualHostedStyle = helper.getOptions().get(S3_VIRTUAL_HOSTED_STYLE);
-        Boolean allowHttp = helper.getOptions().get(S3_ALLOW_HTTP);
-
-        if (accessKey != null) {
-            storageOptions.put("aws_access_key_id", accessKey);
-        }
-        if (secretKey != null) {
-            storageOptions.put("aws_secret_access_key", secretKey);
-        }
-        if (region != null) {
-            storageOptions.put("aws_region", region);
-        }
-        if (endpoint != null) {
-            storageOptions.put("aws_endpoint", endpoint);
-        }
-        if (virtualHostedStyle != null) {
-            storageOptions.put("aws_virtual_hosted_style_request", virtualHostedStyle.toString());
-        }
-        if (allowHttp != null) {
-            storageOptions.put("allow_http", allowHttp.toString());
-        }
-
-        return new LanceCatalog(catalogName, defaultDatabase, warehouse, storageOptions);
+    if (endpoint != null) {
+      storageOptions.put("aws_endpoint", endpoint);
     }
+    if (virtualHostedStyle != null) {
+      storageOptions.put("aws_virtual_hosted_style_request", virtualHostedStyle.toString());
+    }
+    if (allowHttp != null) {
+      storageOptions.put("allow_http", allowHttp.toString());
+    }
+
+    return new LanceCatalog(catalogName, defaultDatabase, warehouse, storageOptions);
+  }
 }
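Not part of the patch: a minimal usage sketch of the catalog factory above, assuming the
connector jar is on the classpath. The warehouse bucket, endpoint, and MinIO-style
credentials below are placeholder values, not anything defined in this repository.

// Hypothetical smoke test for LanceCatalogFactory (placeholder S3 settings).
import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.TableEnvironment;

public class LanceCatalogUsageSketch {
  public static void main(String[] args) {
    TableEnvironment tEnv =
        TableEnvironment.create(EnvironmentSettings.newInstance().inBatchMode().build());

    // 'type' = 'lance' resolves LanceCatalogFactory via SPI; the s3-* keys are
    // copied into the catalog's storage options in createCatalog(...) above.
    tEnv.executeSql(
        "CREATE CATALOG lance_minio WITH ("
            + " 'type' = 'lance',"
            + " 'warehouse' = 's3://example-bucket/lance',"
            + " 'default-database' = 'default',"
            + " 's3-endpoint' = 'http://localhost:9000',"
            + " 's3-access-key' = 'minioadmin',"
            + " 's3-secret-key' = 'minioadmin',"
            + " 's3-allow-http' = 'true'"
            + ")");
    tEnv.executeSql("USE CATALOG lance_minio");
    tEnv.executeSql("SHOW DATABASES").print();
  }
}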
diff --git a/src/main/java/org/apache/flink/connector/lance/table/LanceDynamicTableFactory.java b/src/main/java/org/apache/flink/connector/lance/table/LanceDynamicTableFactory.java
index 4f4f358..03fe725 100644
--- a/src/main/java/org/apache/flink/connector/lance/table/LanceDynamicTableFactory.java
+++ b/src/main/java/org/apache/flink/connector/lance/table/LanceDynamicTableFactory.java
@@ -1,11 +1,7 @@
 /*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
@@ -15,7 +11,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.flink.connector.lance.table;
 
 import org.apache.flink.configuration.ConfigOption;
@@ -33,11 +28,12 @@
 
 /**
  * Lance dynamic table factory.
- * 
+ *
  * <p>Implements Flink Table API DynamicTableSourceFactory and DynamicTableSinkFactory interfaces,
  * supports creating Lance tables via SQL DDL.
- * 
+ *
  * <p>Usage example:
+ *
  * <pre>{@code
  * CREATE TABLE lance_table (
  *     id BIGINT,
@@ -49,189 +45,184 @@
  * );
  * }
*/ -public class LanceDynamicTableFactory implements DynamicTableSourceFactory, DynamicTableSinkFactory { - - public static final String IDENTIFIER = "lance"; - - // ==================== Configuration Options Definition ==================== - - public static final ConfigOption PATH = ConfigOptions - .key("path") - .stringType() - .noDefaultValue() - .withDescription("Lance dataset path"); - - public static final ConfigOption READ_BATCH_SIZE = ConfigOptions - .key("read.batch-size") - .intType() - .defaultValue(1024) - .withDescription("Read batch size"); - - public static final ConfigOption READ_COLUMNS = ConfigOptions - .key("read.columns") - .stringType() - .noDefaultValue() - .withDescription("Columns to read, comma separated"); - - public static final ConfigOption READ_FILTER = ConfigOptions - .key("read.filter") - .stringType() - .noDefaultValue() - .withDescription("Data filter condition"); - - public static final ConfigOption WRITE_BATCH_SIZE = ConfigOptions - .key("write.batch-size") - .intType() - .defaultValue(1024) - .withDescription("Write batch size"); - - public static final ConfigOption WRITE_MODE = ConfigOptions - .key("write.mode") - .stringType() - .defaultValue("append") - .withDescription("Write mode: append or overwrite"); - - public static final ConfigOption WRITE_MAX_ROWS_PER_FILE = ConfigOptions - .key("write.max-rows-per-file") - .intType() - .defaultValue(1000000) - .withDescription("Maximum rows per file"); - - public static final ConfigOption INDEX_TYPE = ConfigOptions - .key("index.type") - .stringType() - .defaultValue("IVF_PQ") - .withDescription("Vector index type"); - - public static final ConfigOption INDEX_COLUMN = ConfigOptions - .key("index.column") - .stringType() - .noDefaultValue() - .withDescription("Index column name"); - - public static final ConfigOption INDEX_NUM_PARTITIONS = ConfigOptions - .key("index.num-partitions") - .intType() - .defaultValue(256) - .withDescription("IVF partition count"); - - public static final ConfigOption INDEX_NUM_SUB_VECTORS = ConfigOptions - .key("index.num-sub-vectors") - .intType() - .noDefaultValue() - .withDescription("PQ sub-vector count"); - - public static final ConfigOption VECTOR_COLUMN = ConfigOptions - .key("vector.column") - .stringType() - .noDefaultValue() - .withDescription("Vector column name"); - - public static final ConfigOption VECTOR_METRIC = ConfigOptions - .key("vector.metric") - .stringType() - .defaultValue("L2") - .withDescription("Distance metric type: L2, Cosine, Dot"); - - public static final ConfigOption VECTOR_NPROBES = ConfigOptions - .key("vector.nprobes") - .intType() - .defaultValue(20) - .withDescription("IVF search probe count"); - - @Override - public String factoryIdentifier() { - return IDENTIFIER; - } - - @Override - public Set> requiredOptions() { - Set> options = new HashSet<>(); - options.add(PATH); - return options; - } - - @Override - public Set> optionalOptions() { - Set> options = new HashSet<>(); - options.add(READ_BATCH_SIZE); - options.add(READ_COLUMNS); - options.add(READ_FILTER); - options.add(WRITE_BATCH_SIZE); - options.add(WRITE_MODE); - options.add(WRITE_MAX_ROWS_PER_FILE); - options.add(INDEX_TYPE); - options.add(INDEX_COLUMN); - options.add(INDEX_NUM_PARTITIONS); - options.add(INDEX_NUM_SUB_VECTORS); - options.add(VECTOR_COLUMN); - options.add(VECTOR_METRIC); - options.add(VECTOR_NPROBES); - return options; - } - - @Override - public DynamicTableSource createDynamicTableSource(Context context) { - FactoryUtil.TableFactoryHelper helper = 
FactoryUtil.createTableFactoryHelper(this, context); - helper.validate(); - - ReadableConfig config = helper.getOptions(); - LanceOptions options = buildLanceOptions(config); - - return new LanceDynamicTableSource( - options, - context.getCatalogTable().getResolvedSchema().toPhysicalRowDataType() - ); - } - - @Override - public DynamicTableSink createDynamicTableSink(Context context) { - FactoryUtil.TableFactoryHelper helper = FactoryUtil.createTableFactoryHelper(this, context); - helper.validate(); - - ReadableConfig config = helper.getOptions(); - LanceOptions options = buildLanceOptions(config); - - return new LanceDynamicTableSink( - options, - context.getCatalogTable().getResolvedSchema().toPhysicalRowDataType() - ); - } - - /** - * Build LanceOptions from configuration - */ - private LanceOptions buildLanceOptions(ReadableConfig config) { - LanceOptions.Builder builder = LanceOptions.builder(); - - // Common configuration - builder.path(config.get(PATH)); - - // Source configuration - builder.readBatchSize(config.get(READ_BATCH_SIZE)); - config.getOptional(READ_COLUMNS).ifPresent(columns -> { - if (!columns.isEmpty()) { +public class LanceDynamicTableFactory + implements DynamicTableSourceFactory, DynamicTableSinkFactory { + + public static final String IDENTIFIER = "lance"; + + // ==================== Configuration Options Definition ==================== + + public static final ConfigOption PATH = + ConfigOptions.key("path").stringType().noDefaultValue().withDescription("Lance dataset path"); + + public static final ConfigOption READ_BATCH_SIZE = + ConfigOptions.key("read.batch-size") + .intType() + .defaultValue(1024) + .withDescription("Read batch size"); + + public static final ConfigOption READ_COLUMNS = + ConfigOptions.key("read.columns") + .stringType() + .noDefaultValue() + .withDescription("Columns to read, comma separated"); + + public static final ConfigOption READ_FILTER = + ConfigOptions.key("read.filter") + .stringType() + .noDefaultValue() + .withDescription("Data filter condition"); + + public static final ConfigOption WRITE_BATCH_SIZE = + ConfigOptions.key("write.batch-size") + .intType() + .defaultValue(1024) + .withDescription("Write batch size"); + + public static final ConfigOption WRITE_MODE = + ConfigOptions.key("write.mode") + .stringType() + .defaultValue("append") + .withDescription("Write mode: append or overwrite"); + + public static final ConfigOption WRITE_MAX_ROWS_PER_FILE = + ConfigOptions.key("write.max-rows-per-file") + .intType() + .defaultValue(1000000) + .withDescription("Maximum rows per file"); + + public static final ConfigOption INDEX_TYPE = + ConfigOptions.key("index.type") + .stringType() + .defaultValue("IVF_PQ") + .withDescription("Vector index type"); + + public static final ConfigOption INDEX_COLUMN = + ConfigOptions.key("index.column") + .stringType() + .noDefaultValue() + .withDescription("Index column name"); + + public static final ConfigOption INDEX_NUM_PARTITIONS = + ConfigOptions.key("index.num-partitions") + .intType() + .defaultValue(256) + .withDescription("IVF partition count"); + + public static final ConfigOption INDEX_NUM_SUB_VECTORS = + ConfigOptions.key("index.num-sub-vectors") + .intType() + .noDefaultValue() + .withDescription("PQ sub-vector count"); + + public static final ConfigOption VECTOR_COLUMN = + ConfigOptions.key("vector.column") + .stringType() + .noDefaultValue() + .withDescription("Vector column name"); + + public static final ConfigOption VECTOR_METRIC = + ConfigOptions.key("vector.metric") + .stringType() 
+ .defaultValue("L2") + .withDescription("Distance metric type: L2, Cosine, Dot"); + + public static final ConfigOption VECTOR_NPROBES = + ConfigOptions.key("vector.nprobes") + .intType() + .defaultValue(20) + .withDescription("IVF search probe count"); + + @Override + public String factoryIdentifier() { + return IDENTIFIER; + } + + @Override + public Set> requiredOptions() { + Set> options = new HashSet<>(); + options.add(PATH); + return options; + } + + @Override + public Set> optionalOptions() { + Set> options = new HashSet<>(); + options.add(READ_BATCH_SIZE); + options.add(READ_COLUMNS); + options.add(READ_FILTER); + options.add(WRITE_BATCH_SIZE); + options.add(WRITE_MODE); + options.add(WRITE_MAX_ROWS_PER_FILE); + options.add(INDEX_TYPE); + options.add(INDEX_COLUMN); + options.add(INDEX_NUM_PARTITIONS); + options.add(INDEX_NUM_SUB_VECTORS); + options.add(VECTOR_COLUMN); + options.add(VECTOR_METRIC); + options.add(VECTOR_NPROBES); + return options; + } + + @Override + public DynamicTableSource createDynamicTableSource(Context context) { + FactoryUtil.TableFactoryHelper helper = FactoryUtil.createTableFactoryHelper(this, context); + helper.validate(); + + ReadableConfig config = helper.getOptions(); + LanceOptions options = buildLanceOptions(config); + + return new LanceDynamicTableSource( + options, context.getCatalogTable().getResolvedSchema().toPhysicalRowDataType()); + } + + @Override + public DynamicTableSink createDynamicTableSink(Context context) { + FactoryUtil.TableFactoryHelper helper = FactoryUtil.createTableFactoryHelper(this, context); + helper.validate(); + + ReadableConfig config = helper.getOptions(); + LanceOptions options = buildLanceOptions(config); + + return new LanceDynamicTableSink( + options, context.getCatalogTable().getResolvedSchema().toPhysicalRowDataType()); + } + + /** Build LanceOptions from configuration */ + private LanceOptions buildLanceOptions(ReadableConfig config) { + LanceOptions.Builder builder = LanceOptions.builder(); + + // Common configuration + builder.path(config.get(PATH)); + + // Source configuration + builder.readBatchSize(config.get(READ_BATCH_SIZE)); + config + .getOptional(READ_COLUMNS) + .ifPresent( + columns -> { + if (!columns.isEmpty()) { builder.readColumns(java.util.Arrays.asList(columns.split(","))); - } - }); - config.getOptional(READ_FILTER).ifPresent(builder::readFilter); - - // Sink configuration - builder.writeBatchSize(config.get(WRITE_BATCH_SIZE)); - builder.writeMode(LanceOptions.WriteMode.fromValue(config.get(WRITE_MODE))); - builder.writeMaxRowsPerFile(config.get(WRITE_MAX_ROWS_PER_FILE)); - - // Index configuration - builder.indexType(LanceOptions.IndexType.fromValue(config.get(INDEX_TYPE))); - config.getOptional(INDEX_COLUMN).ifPresent(builder::indexColumn); - builder.indexNumPartitions(config.get(INDEX_NUM_PARTITIONS)); - config.getOptional(INDEX_NUM_SUB_VECTORS).ifPresent(builder::indexNumSubVectors); - - // Vector search configuration - config.getOptional(VECTOR_COLUMN).ifPresent(builder::vectorColumn); - builder.vectorMetric(LanceOptions.MetricType.fromValue(config.get(VECTOR_METRIC))); - builder.vectorNprobes(config.get(VECTOR_NPROBES)); - - return builder.build(); - } + } + }); + config.getOptional(READ_FILTER).ifPresent(builder::readFilter); + + // Sink configuration + builder.writeBatchSize(config.get(WRITE_BATCH_SIZE)); + builder.writeMode(LanceOptions.WriteMode.fromValue(config.get(WRITE_MODE))); + builder.writeMaxRowsPerFile(config.get(WRITE_MAX_ROWS_PER_FILE)); + + // Index configuration + 
builder.indexType(LanceOptions.IndexType.fromValue(config.get(INDEX_TYPE))); + config.getOptional(INDEX_COLUMN).ifPresent(builder::indexColumn); + builder.indexNumPartitions(config.get(INDEX_NUM_PARTITIONS)); + config.getOptional(INDEX_NUM_SUB_VECTORS).ifPresent(builder::indexNumSubVectors); + + // Vector search configuration + config.getOptional(VECTOR_COLUMN).ifPresent(builder::vectorColumn); + builder.vectorMetric(LanceOptions.MetricType.fromValue(config.get(VECTOR_METRIC))); + builder.vectorNprobes(config.get(VECTOR_NPROBES)); + + return builder.build(); + } } diff --git a/src/main/java/org/apache/flink/connector/lance/table/LanceDynamicTableSink.java b/src/main/java/org/apache/flink/connector/lance/table/LanceDynamicTableSink.java index 09ed914..89b0a49 100644 --- a/src/main/java/org/apache/flink/connector/lance/table/LanceDynamicTableSink.java +++ b/src/main/java/org/apache/flink/connector/lance/table/LanceDynamicTableSink.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,77 +11,65 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.flink.connector.lance.table; import org.apache.flink.connector.lance.LanceSink; import org.apache.flink.connector.lance.config.LanceOptions; -import org.apache.flink.streaming.api.datastream.DataStream; -import org.apache.flink.streaming.api.datastream.DataStreamSink; -import org.apache.flink.streaming.api.functions.sink.SinkFunction; import org.apache.flink.table.connector.ChangelogMode; -import org.apache.flink.table.connector.sink.DataStreamSinkProvider; import org.apache.flink.table.connector.sink.DynamicTableSink; import org.apache.flink.table.connector.sink.SinkFunctionProvider; -import org.apache.flink.table.data.RowData; import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.logical.RowType; import org.apache.flink.types.RowKind; /** * Lance dynamic table sink. - * + * *
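Review aid for the LanceDynamicTableFactory diff above: the option set maps one-to-one onto table DDL, and only 'path' is required; everything else falls back to the defaults declared in the ConfigOptions. A minimal sketch of exercising it through the Table API (table name, schema, and dataset path are made up for illustration):

    import org.apache.flink.table.api.EnvironmentSettings;
    import org.apache.flink.table.api.TableEnvironment;

    public class LanceDdlExample {
      public static void main(String[] args) {
        TableEnvironment tEnv =
            TableEnvironment.create(EnvironmentSettings.newInstance().inBatchMode().build());

        // 'path' is the only required option; read.batch-size and write.mode
        // override the defaults (1024 and append) defined in the factory above.
        tEnv.executeSql(
            "CREATE TABLE lance_docs ("
                + "  id BIGINT,"
                + "  content STRING,"
                + "  embedding ARRAY<FLOAT>"
                + ") WITH ("
                + "  'connector' = 'lance',"
                + "  'path' = '/tmp/lance/docs',"
                + "  'read.batch-size' = '512',"
                + "  'write.mode' = 'overwrite'"
                + ")");
      }
    }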

Implements DynamicTableSink interface, supports writing Flink data to Lance dataset. */ public class LanceDynamicTableSink implements DynamicTableSink { - private final LanceOptions options; - private final DataType physicalDataType; + private final LanceOptions options; + private final DataType physicalDataType; - public LanceDynamicTableSink(LanceOptions options, DataType physicalDataType) { - this.options = options; - this.physicalDataType = physicalDataType; - } + public LanceDynamicTableSink(LanceOptions options, DataType physicalDataType) { + this.options = options; + this.physicalDataType = physicalDataType; + } - @Override - public ChangelogMode getChangelogMode(ChangelogMode requestedMode) { - // Lance only supports INSERT operations - return ChangelogMode.newBuilder() - .addContainedKind(RowKind.INSERT) - .build(); - } + @Override + public ChangelogMode getChangelogMode(ChangelogMode requestedMode) { + // Lance only supports INSERT operations + return ChangelogMode.newBuilder().addContainedKind(RowKind.INSERT).build(); + } - @Override - public SinkRuntimeProvider getSinkRuntimeProvider(Context context) { - RowType rowType = (RowType) physicalDataType.getLogicalType(); + @Override + public SinkRuntimeProvider getSinkRuntimeProvider(Context context) { + RowType rowType = (RowType) physicalDataType.getLogicalType(); - // Create LanceSink - LanceSink lanceSink = new LanceSink(options, rowType); + // Create LanceSink + LanceSink lanceSink = new LanceSink(options, rowType); - return SinkFunctionProvider.of(lanceSink); - } + return SinkFunctionProvider.of(lanceSink); + } - @Override - public DynamicTableSink copy() { - return new LanceDynamicTableSink(options, physicalDataType); - } + @Override + public DynamicTableSink copy() { + return new LanceDynamicTableSink(options, physicalDataType); + } - @Override - public String asSummaryString() { - return "Lance Table Sink"; - } + @Override + public String asSummaryString() { + return "Lance Table Sink"; + } - /** - * Get configuration options - */ - public LanceOptions getOptions() { - return options; - } + /** Get configuration options */ + public LanceOptions getOptions() { + return options; + } - /** - * Get physical data type - */ - public DataType getPhysicalDataType() { - return physicalDataType; - } + /** Get physical data type */ + public DataType getPhysicalDataType() { + return physicalDataType; + } } diff --git a/src/main/java/org/apache/flink/connector/lance/table/LanceDynamicTableSource.java b/src/main/java/org/apache/flink/connector/lance/table/LanceDynamicTableSource.java index dfcf186..75c99ad 100644 --- a/src/main/java/org/apache/flink/connector/lance/table/LanceDynamicTableSource.java +++ b/src/main/java/org/apache/flink/connector/lance/table/LanceDynamicTableSource.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
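One behavioral point in the LanceDynamicTableSink diff above that the reformatting makes easy to skim past: getChangelogMode() keeps only RowKind.INSERT, so the sink is append-only. A hedged sketch of what that means in practice (table names and the datagen source are illustrative, not part of this patch):

    import org.apache.flink.table.api.EnvironmentSettings;
    import org.apache.flink.table.api.TableEnvironment;

    public class LanceInsertOnlyExample {
      public static void main(String[] args) {
        TableEnvironment tEnv =
            TableEnvironment.create(EnvironmentSettings.newInstance().inStreamingMode().build());
        tEnv.executeSql(
            "CREATE TABLE lance_docs (id BIGINT, content STRING) WITH ("
                + " 'connector' = 'lance', 'path' = '/tmp/lance/docs')");
        tEnv.executeSql(
            "CREATE TABLE clicks (id BIGINT, content STRING) WITH ('connector' = 'datagen')");

        // Plain appends work: the sink's ChangelogMode contains only INSERT.
        tEnv.executeSql("INSERT INTO lance_docs SELECT id, content FROM clicks");

        // By contrast, an updating query such as
        //   INSERT INTO lance_docs SELECT id, MAX(content) FROM clicks GROUP BY id
        // is rejected at planning time because it emits update rows.
      }
    }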
+ * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,21 +11,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.flink.connector.lance.table; -import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.connector.lance.LanceInputFormat; import org.apache.flink.connector.lance.LanceSource; import org.apache.flink.connector.lance.aggregate.AggregateInfo; import org.apache.flink.connector.lance.config.LanceOptions; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; -import org.apache.flink.table.api.TableSchema; import org.apache.flink.table.connector.ChangelogMode; import org.apache.flink.table.connector.source.DataStreamScanProvider; import org.apache.flink.table.connector.source.DynamicTableSource; -import org.apache.flink.table.connector.source.InputFormatProvider; import org.apache.flink.table.connector.source.ScanTableSource; import org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown; import org.apache.flink.table.connector.source.abilities.SupportsFilterPushDown; @@ -44,9 +35,7 @@ import org.apache.flink.table.functions.BuiltInFunctionDefinitions; import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.types.DataType; -import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.table.types.logical.RowType; -import org.apache.flink.types.RowKind; import java.util.ArrayList; import java.util.Arrays; @@ -55,473 +44,457 @@ /** * Lance dynamic table source. - * + * *

Implements ScanTableSource interface, supports column pruning and filter push-down. */ -public class LanceDynamicTableSource implements ScanTableSource, - SupportsProjectionPushDown, SupportsFilterPushDown, SupportsLimitPushDown, +public class LanceDynamicTableSource + implements ScanTableSource, + SupportsProjectionPushDown, + SupportsFilterPushDown, + SupportsLimitPushDown, SupportsAggregatePushDown { - private final LanceOptions options; - private final DataType physicalDataType; - private int[] projectedFields; - private List filters; - private Long limit; // Limit push-down - private AggregateInfo aggregateInfo; // Aggregate push-down - private boolean aggregatePushDownAccepted; // Whether aggregate push-down is accepted - - public LanceDynamicTableSource(LanceOptions options, DataType physicalDataType) { - this.options = options; - this.physicalDataType = physicalDataType; - this.projectedFields = null; - this.filters = new ArrayList<>(); - this.limit = null; - this.aggregateInfo = null; - this.aggregatePushDownAccepted = false; + private final LanceOptions options; + private final DataType physicalDataType; + private int[] projectedFields; + private List filters; + private Long limit; // Limit push-down + private AggregateInfo aggregateInfo; // Aggregate push-down + private boolean aggregatePushDownAccepted; // Whether aggregate push-down is accepted + + public LanceDynamicTableSource(LanceOptions options, DataType physicalDataType) { + this.options = options; + this.physicalDataType = physicalDataType; + this.projectedFields = null; + this.filters = new ArrayList<>(); + this.limit = null; + this.aggregateInfo = null; + this.aggregatePushDownAccepted = false; + } + + private LanceDynamicTableSource(LanceDynamicTableSource source) { + this.options = source.options; + this.physicalDataType = source.physicalDataType; + this.projectedFields = source.projectedFields; + this.filters = new ArrayList<>(source.filters); + this.limit = source.limit; + this.aggregateInfo = source.aggregateInfo; + this.aggregatePushDownAccepted = source.aggregatePushDownAccepted; + } + + @Override + public ChangelogMode getChangelogMode() { + return ChangelogMode.insertOnly(); + } + + @Override + public ScanRuntimeProvider getScanRuntimeProvider(ScanContext runtimeProviderContext) { + RowType rowType = (RowType) physicalDataType.getLogicalType(); + + // If column pruning applied, build new RowType + RowType projectedRowType = rowType; + if (projectedFields != null) { + List projectedFieldList = new ArrayList<>(); + for (int fieldIndex : projectedFields) { + projectedFieldList.add(rowType.getFields().get(fieldIndex)); + } + projectedRowType = new RowType(projectedFieldList); } - private LanceDynamicTableSource(LanceDynamicTableSource source) { - this.options = source.options; - this.physicalDataType = source.physicalDataType; - this.projectedFields = source.projectedFields; - this.filters = new ArrayList<>(source.filters); - this.limit = source.limit; - this.aggregateInfo = source.aggregateInfo; - this.aggregatePushDownAccepted = source.aggregatePushDownAccepted; - } + // Build LanceOptions (apply column pruning and filter conditions) + LanceOptions.Builder optionsBuilder = + LanceOptions.builder() + .path(options.getPath()) + .readBatchSize(options.getReadBatchSize()) + .readFilter(buildFilterExpression()); - @Override - public ChangelogMode getChangelogMode() { - return ChangelogMode.insertOnly(); + // Set Limit (if any) + if (limit != null) { + optionsBuilder.readLimit(limit); } - @Override - public 
ScanRuntimeProvider getScanRuntimeProvider(ScanContext runtimeProviderContext) { - RowType rowType = (RowType) physicalDataType.getLogicalType(); - - // If column pruning applied, build new RowType - RowType projectedRowType = rowType; - if (projectedFields != null) { - List projectedFieldList = new ArrayList<>(); - for (int fieldIndex : projectedFields) { - projectedFieldList.add(rowType.getFields().get(fieldIndex)); - } - projectedRowType = new RowType(projectedFieldList); - } + // Set columns to read + if (projectedFields != null) { + List columnNames = + Arrays.stream(projectedFields) + .mapToObj(i -> rowType.getFieldNames().get(i)) + .collect(Collectors.toList()); + optionsBuilder.readColumns(columnNames); + } - // Build LanceOptions (apply column pruning and filter conditions) - LanceOptions.Builder optionsBuilder = LanceOptions.builder() - .path(options.getPath()) - .readBatchSize(options.getReadBatchSize()) - .readFilter(buildFilterExpression()); + LanceOptions finalOptions = optionsBuilder.build(); + final RowType finalRowType = projectedRowType; + + // Use DataStreamScanProvider + return new DataStreamScanProvider() { + @Override + public DataStream produceDataStream(StreamExecutionEnvironment execEnv) { + LanceSource source = new LanceSource(finalOptions, finalRowType); + return execEnv.addSource(source, "LanceSource"); + } + + @Override + public boolean isBounded() { + return true; // Lance dataset is bounded + } + }; + } + + @Override + public DynamicTableSource copy() { + return new LanceDynamicTableSource(this); + } + + @Override + public String asSummaryString() { + return "Lance Table Source"; + } + + // ==================== SupportsProjectionPushDown ==================== + + @Override + public boolean supportsNestedProjection() { + return false; + } + + @Override + public void applyProjection(int[][] projectedFields) { + // Only support top-level field projection + this.projectedFields = Arrays.stream(projectedFields).mapToInt(arr -> arr[0]).toArray(); + } + + // ==================== SupportsFilterPushDown ==================== + + @Override + public Result applyFilters(List filters) { + // Convert Flink expressions to Lance filter conditions + List acceptedFilters = new ArrayList<>(); + List remainingFilters = new ArrayList<>(); + + for (ResolvedExpression filter : filters) { + String lanceFilter = convertToLanceFilter(filter); + if (lanceFilter != null) { + this.filters.add(lanceFilter); + acceptedFilters.add(filter); + } else { + remainingFilters.add(filter); + } + } - // Set Limit (if any) - if (limit != null) { - optionsBuilder.readLimit(limit); + return Result.of(acceptedFilters, remainingFilters); + } + + /** + * Convert Flink expression to Lance filter condition. 
Lance supports standard SQL filter syntax, + * e.g., column = 'value', column > 10 + */ + private String convertToLanceFilter(ResolvedExpression expression) { + try { + if (expression instanceof CallExpression) { + CallExpression callExpr = (CallExpression) expression; + return convertCallExpression(callExpr); + } + // Other expression types not supported for push-down + return null; + } catch (Exception e) { + // Return null for unconvertible expressions, handled by Flink at upper layer + return null; + } + } + + /** Convert CallExpression to Lance filter string */ + private String convertCallExpression(CallExpression callExpr) { + FunctionDefinition funcDef = callExpr.getFunctionDefinition(); + List args = callExpr.getResolvedChildren(); + + // Comparison operators + if (funcDef == BuiltInFunctionDefinitions.EQUALS) { + return buildComparisonFilter(args, "="); + } else if (funcDef == BuiltInFunctionDefinitions.NOT_EQUALS) { + return buildComparisonFilter(args, "!="); + } else if (funcDef == BuiltInFunctionDefinitions.GREATER_THAN) { + return buildComparisonFilter(args, ">"); + } else if (funcDef == BuiltInFunctionDefinitions.GREATER_THAN_OR_EQUAL) { + return buildComparisonFilter(args, ">="); + } else if (funcDef == BuiltInFunctionDefinitions.LESS_THAN) { + return buildComparisonFilter(args, "<"); + } else if (funcDef == BuiltInFunctionDefinitions.LESS_THAN_OR_EQUAL) { + return buildComparisonFilter(args, "<="); + } + // Logical operators + else if (funcDef == BuiltInFunctionDefinitions.AND) { + return buildLogicalFilter(args, "AND"); + } else if (funcDef == BuiltInFunctionDefinitions.OR) { + return buildLogicalFilter(args, "OR"); + } else if (funcDef == BuiltInFunctionDefinitions.NOT) { + if (args.size() == 1) { + String inner = convertToLanceFilter(args.get(0)); + if (inner != null) { + return "NOT (" + inner + ")"; } + } + } + // IS NULL / IS NOT NULL + else if (funcDef == BuiltInFunctionDefinitions.IS_NULL) { + if (args.size() == 1 && args.get(0) instanceof FieldReferenceExpression) { + String fieldName = ((FieldReferenceExpression) args.get(0)).getName(); + return fieldName + " IS NULL"; + } + } else if (funcDef == BuiltInFunctionDefinitions.IS_NOT_NULL) { + if (args.size() == 1 && args.get(0) instanceof FieldReferenceExpression) { + String fieldName = ((FieldReferenceExpression) args.get(0)).getName(); + return fieldName + " IS NOT NULL"; + } + } + // LIKE + else if (funcDef == BuiltInFunctionDefinitions.LIKE) { + return buildComparisonFilter(args, "LIKE"); + } + // IN (not supported yet, requires more complex handling) + // BETWEEN (not supported yet) - // Set columns to read - if (projectedFields != null) { - List columnNames = Arrays.stream(projectedFields) - .mapToObj(i -> rowType.getFieldNames().get(i)) - .collect(Collectors.toList()); - optionsBuilder.readColumns(columnNames); - } + // Unsupported functions, return null + return null; + } - LanceOptions finalOptions = optionsBuilder.build(); - final RowType finalRowType = projectedRowType; - - // Use DataStreamScanProvider - return new DataStreamScanProvider() { - @Override - public DataStream produceDataStream(StreamExecutionEnvironment execEnv) { - LanceSource source = new LanceSource(finalOptions, finalRowType); - return execEnv.addSource(source, "LanceSource"); - } - - @Override - public boolean isBounded() { - return true; // Lance dataset is bounded - } - }; + /** Build comparison filter expression */ + private String buildComparisonFilter(List args, String operator) { + if (args.size() != 2) { + return null; } - 
@Override - public DynamicTableSource copy() { - return new LanceDynamicTableSource(this); + ResolvedExpression left = args.get(0); + ResolvedExpression right = args.get(1); + + // Extract field name and value + String fieldName = null; + String value = null; + + if (left instanceof FieldReferenceExpression) { + fieldName = ((FieldReferenceExpression) left).getName(); + value = extractLiteralValue(right); + } else if (right instanceof FieldReferenceExpression) { + fieldName = ((FieldReferenceExpression) right).getName(); + value = extractLiteralValue(left); + // For asymmetric operators, need to swap operator + if (">".equals(operator)) { + operator = "<"; + } else if ("<".equals(operator)) { + operator = ">"; + } else if (">=".equals(operator)) { + operator = "<="; + } else if ("<=".equals(operator)) { + operator = ">="; + } } - @Override - public String asSummaryString() { - return "Lance Table Source"; + if (fieldName != null && value != null) { + return fieldName + " " + operator + " " + value; } - // ==================== SupportsProjectionPushDown ==================== - - @Override - public boolean supportsNestedProjection() { - return false; + return null; + } + + /** Build logical filter expression */ + private String buildLogicalFilter(List args, String operator) { + List convertedArgs = new ArrayList<>(); + for (ResolvedExpression arg : args) { + String converted = convertToLanceFilter(arg); + if (converted == null) { + return null; // If any sub-expression cannot be converted, don't push down entire expression + } + convertedArgs.add("(" + converted + ")"); + } + return String.join(" " + operator + " ", convertedArgs); + } + + /** Extract literal value from ValueLiteralExpression */ + private String extractLiteralValue(ResolvedExpression expr) { + if (expr instanceof ValueLiteralExpression) { + ValueLiteralExpression literal = (ValueLiteralExpression) expr; + Object value = literal.getValueAs(Object.class).orElse(null); + + if (value == null) { + return "NULL"; + } else if (value instanceof String) { + // Strings need single quotes and escape internal single quotes + String strValue = (String) value; + strValue = strValue.replace("'", "''"); + return "'" + strValue + "'"; + } else if (value instanceof Number) { + return value.toString(); + } else if (value instanceof Boolean) { + return value.toString().toUpperCase(); + } else { + // Other types try to convert to string + return "'" + value.toString().replace("'", "''") + "'"; + } } + return null; + } - @Override - public void applyProjection(int[][] projectedFields) { - // Only support top-level field projection - this.projectedFields = Arrays.stream(projectedFields) - .mapToInt(arr -> arr[0]) - .toArray(); + /** Build filter expression */ + private String buildFilterExpression() { + if (filters.isEmpty()) { + return options.getReadFilter(); } - // ==================== SupportsFilterPushDown ==================== - - @Override - public Result applyFilters(List filters) { - // Convert Flink expressions to Lance filter conditions - List acceptedFilters = new ArrayList<>(); - List remainingFilters = new ArrayList<>(); - - for (ResolvedExpression filter : filters) { - String lanceFilter = convertToLanceFilter(filter); - if (lanceFilter != null) { - this.filters.add(lanceFilter); - acceptedFilters.add(filter); - } else { - remainingFilters.add(filter); - } - } + String combinedFilter = String.join(" AND ", filters); + String originalFilter = options.getReadFilter(); - return Result.of(acceptedFilters, remainingFilters); + if 
(originalFilter != null && !originalFilter.isEmpty()) { + return "(" + originalFilter + ") AND (" + combinedFilter + ")"; } - /** - * Convert Flink expression to Lance filter condition. - * Lance supports standard SQL filter syntax, e.g., column = 'value', column > 10 - */ - private String convertToLanceFilter(ResolvedExpression expression) { - try { - if (expression instanceof CallExpression) { - CallExpression callExpr = (CallExpression) expression; - return convertCallExpression(callExpr); - } - // Other expression types not supported for push-down - return null; - } catch (Exception e) { - // Return null for unconvertible expressions, handled by Flink at upper layer - return null; - } - } + return combinedFilter; + } - /** - * Convert CallExpression to Lance filter string - */ - private String convertCallExpression(CallExpression callExpr) { - FunctionDefinition funcDef = callExpr.getFunctionDefinition(); - List args = callExpr.getResolvedChildren(); - - // Comparison operators - if (funcDef == BuiltInFunctionDefinitions.EQUALS) { - return buildComparisonFilter(args, "="); - } else if (funcDef == BuiltInFunctionDefinitions.NOT_EQUALS) { - return buildComparisonFilter(args, "!="); - } else if (funcDef == BuiltInFunctionDefinitions.GREATER_THAN) { - return buildComparisonFilter(args, ">"); - } else if (funcDef == BuiltInFunctionDefinitions.GREATER_THAN_OR_EQUAL) { - return buildComparisonFilter(args, ">="); - } else if (funcDef == BuiltInFunctionDefinitions.LESS_THAN) { - return buildComparisonFilter(args, "<"); - } else if (funcDef == BuiltInFunctionDefinitions.LESS_THAN_OR_EQUAL) { - return buildComparisonFilter(args, "<="); - } - // Logical operators - else if (funcDef == BuiltInFunctionDefinitions.AND) { - return buildLogicalFilter(args, "AND"); - } else if (funcDef == BuiltInFunctionDefinitions.OR) { - return buildLogicalFilter(args, "OR"); - } else if (funcDef == BuiltInFunctionDefinitions.NOT) { - if (args.size() == 1) { - String inner = convertToLanceFilter(args.get(0)); - if (inner != null) { - return "NOT (" + inner + ")"; - } - } - } - // IS NULL / IS NOT NULL - else if (funcDef == BuiltInFunctionDefinitions.IS_NULL) { - if (args.size() == 1 && args.get(0) instanceof FieldReferenceExpression) { - String fieldName = ((FieldReferenceExpression) args.get(0)).getName(); - return fieldName + " IS NULL"; - } - } else if (funcDef == BuiltInFunctionDefinitions.IS_NOT_NULL) { - if (args.size() == 1 && args.get(0) instanceof FieldReferenceExpression) { - String fieldName = ((FieldReferenceExpression) args.get(0)).getName(); - return fieldName + " IS NOT NULL"; - } - } - // LIKE - else if (funcDef == BuiltInFunctionDefinitions.LIKE) { - return buildComparisonFilter(args, "LIKE"); - } - // IN (not supported yet, requires more complex handling) - // BETWEEN (not supported yet) + /** Get configuration options */ + public LanceOptions getOptions() { + return options; + } - // Unsupported functions, return null - return null; - } + /** Get physical data type */ + public DataType getPhysicalDataType() { + return physicalDataType; + } - /** - * Build comparison filter expression - */ - private String buildComparisonFilter(List args, String operator) { - if (args.size() != 2) { - return null; - } + // ==================== SupportsLimitPushDown ==================== - ResolvedExpression left = args.get(0); - ResolvedExpression right = args.get(1); - - // Extract field name and value - String fieldName = null; - String value = null; - - if (left instanceof FieldReferenceExpression) { - fieldName = 
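The filter push-down helpers above are behavior-neutral in this commit (only reformatted), but their string handling is easier to review written out. A minimal standalone sketch, not connector code, mirroring the quoting rule of extractLiteralValue and the AND-combination rule of buildLogicalFilter:

    import java.util.Arrays;
    import java.util.List;
    import java.util.stream.Collectors;

    public class LanceFilterSketch {

      // Mirrors extractLiteralValue: numbers pass through, booleans are
      // upper-cased, everything else is single-quoted with '' escaping.
      static String literal(Object v) {
        if (v == null) {
          return "NULL";
        }
        if (v instanceof Number) {
          return v.toString();
        }
        if (v instanceof Boolean) {
          return v.toString().toUpperCase();
        }
        return "'" + v.toString().replace("'", "''") + "'";
      }

      // Mirrors buildLogicalFilter: every operand parenthesized, then joined.
      static String and(List<String> parts) {
        return parts.stream().map(p -> "(" + p + ")").collect(Collectors.joining(" AND "));
      }

      public static void main(String[] args) {
        System.out.println("name = " + literal("O'Brien")); // name = 'O''Brien'
        System.out.println(and(Arrays.asList("id > 10", "active IS NOT NULL")));
        // prints: (id > 10) AND (active IS NOT NULL)
      }
    }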
((FieldReferenceExpression) left).getName(); - value = extractLiteralValue(right); - } else if (right instanceof FieldReferenceExpression) { - fieldName = ((FieldReferenceExpression) right).getName(); - value = extractLiteralValue(left); - // For asymmetric operators, need to swap operator - if (">".equals(operator)) operator = "<"; - else if ("<".equals(operator)) operator = ">"; - else if (">=".equals(operator)) operator = "<="; - else if ("<=".equals(operator)) operator = ">="; - } + @Override + public void applyLimit(long limit) { + this.limit = limit; + } - if (fieldName != null && value != null) { - return fieldName + " " + operator + " " + value; - } + /** Get Limit value */ + public Long getLimit() { + return limit; + } - return null; - } + // ==================== SupportsAggregatePushDown ==================== - /** - * Build logical filter expression - */ - private String buildLogicalFilter(List args, String operator) { - List convertedArgs = new ArrayList<>(); - for (ResolvedExpression arg : args) { - String converted = convertToLanceFilter(arg); - if (converted == null) { - return null; // If any sub-expression cannot be converted, don't push down entire expression - } - convertedArgs.add("(" + converted + ")"); - } - return String.join(" " + operator + " ", convertedArgs); - } + @Override + public boolean applyAggregates( + List groupingSets, + List aggregateExpressions, + DataType producedDataType) { - /** - * Extract literal value from ValueLiteralExpression - */ - private String extractLiteralValue(ResolvedExpression expr) { - if (expr instanceof ValueLiteralExpression) { - ValueLiteralExpression literal = (ValueLiteralExpression) expr; - Object value = literal.getValueAs(Object.class).orElse(null); - - if (value == null) { - return "NULL"; - } else if (value instanceof String) { - // Strings need single quotes and escape internal single quotes - String strValue = (String) value; - strValue = strValue.replace("'", "''"); - return "'" + strValue + "'"; - } else if (value instanceof Number) { - return value.toString(); - } else if (value instanceof Boolean) { - return value.toString().toUpperCase(); - } else { - // Other types try to convert to string - return "'" + value.toString().replace("'", "''") + "'"; - } - } - return null; + // Currently only support simple single grouping set + if (groupingSets.size() != 1) { + return false; } - /** - * Build filter expression - */ - private String buildFilterExpression() { - if (filters.isEmpty()) { - return options.getReadFilter(); - } + int[] groupingSet = groupingSets.get(0); + RowType rowType = (RowType) physicalDataType.getLogicalType(); + List fieldNames = rowType.getFieldNames(); - String combinedFilter = String.join(" AND ", filters); - String originalFilter = options.getReadFilter(); + try { + AggregateInfo.Builder builder = AggregateInfo.builder(); - if (originalFilter != null && !originalFilter.isEmpty()) { - return "(" + originalFilter + ") AND (" + combinedFilter + ")"; + // Handle grouping columns + List groupByColumns = new ArrayList<>(); + for (int fieldIndex : groupingSet) { + if (fieldIndex >= 0 && fieldIndex < fieldNames.size()) { + groupByColumns.add(fieldNames.get(fieldIndex)); } + } + builder.groupBy(groupByColumns); + builder.groupByFieldIndices(groupingSet); + + // Handle aggregate expressions + int aggIndex = 0; + for (AggregateExpression aggExpr : aggregateExpressions) { + AggregateInfo.AggregateCall aggCall = + convertAggregateExpression(aggExpr, fieldNames, aggIndex++); + if (aggCall == null) { + // 
Unsupported aggregate function, reject push-down + return false; + } + builder.addAggregateCall(aggCall); + } - return combinedFilter; - } + this.aggregateInfo = builder.build(); + this.aggregatePushDownAccepted = true; + return true; - /** - * Get configuration options - */ - public LanceOptions getOptions() { - return options; + } catch (Exception e) { + // Conversion failed, reject push-down + return false; } + } - /** - * Get physical data type - */ - public DataType getPhysicalDataType() { - return physicalDataType; - } + /** Convert Flink aggregate expression to internal aggregate call */ + private AggregateInfo.AggregateCall convertAggregateExpression( + AggregateExpression aggExpr, List fieldNames, int aggIndex) { - // ==================== SupportsLimitPushDown ==================== + FunctionDefinition funcDef = aggExpr.getFunctionDefinition(); + List args = aggExpr.getArgs(); + String alias = "agg_" + aggIndex; - @Override - public void applyLimit(long limit) { - this.limit = limit; + // COUNT(*) + if (funcDef == BuiltInFunctionDefinitions.COUNT) { + if (args.isEmpty()) { + // COUNT(*) + return new AggregateInfo.AggregateCall(AggregateInfo.AggregateFunction.COUNT, null, alias); + } else { + // COUNT(column) + String columnName = args.get(0).getName(); + return new AggregateInfo.AggregateCall( + AggregateInfo.AggregateFunction.COUNT, columnName, alias); + } } - /** - * Get Limit value - */ - public Long getLimit() { - return limit; + // SUM + if (funcDef == BuiltInFunctionDefinitions.SUM || funcDef == BuiltInFunctionDefinitions.SUM0) { + if (args.isEmpty()) { + return null; + } + String columnName = args.get(0).getName(); + return new AggregateInfo.AggregateCall( + AggregateInfo.AggregateFunction.SUM, columnName, alias); } - // ==================== SupportsAggregatePushDown ==================== - - @Override - public boolean applyAggregates( - List groupingSets, - List aggregateExpressions, - DataType producedDataType) { - - // Currently only support simple single grouping set - if (groupingSets.size() != 1) { - return false; - } - - int[] groupingSet = groupingSets.get(0); - RowType rowType = (RowType) physicalDataType.getLogicalType(); - List fieldNames = rowType.getFieldNames(); - - try { - AggregateInfo.Builder builder = AggregateInfo.builder(); - - // Handle grouping columns - List groupByColumns = new ArrayList<>(); - for (int fieldIndex : groupingSet) { - if (fieldIndex >= 0 && fieldIndex < fieldNames.size()) { - groupByColumns.add(fieldNames.get(fieldIndex)); - } - } - builder.groupBy(groupByColumns); - builder.groupByFieldIndices(groupingSet); - - // Handle aggregate expressions - int aggIndex = 0; - for (AggregateExpression aggExpr : aggregateExpressions) { - AggregateInfo.AggregateCall aggCall = convertAggregateExpression(aggExpr, fieldNames, aggIndex++); - if (aggCall == null) { - // Unsupported aggregate function, reject push-down - return false; - } - builder.addAggregateCall(aggCall); - } - - this.aggregateInfo = builder.build(); - this.aggregatePushDownAccepted = true; - return true; - - } catch (Exception e) { - // Conversion failed, reject push-down - return false; - } + // AVG + if (funcDef == BuiltInFunctionDefinitions.AVG) { + if (args.isEmpty()) { + return null; + } + String columnName = args.get(0).getName(); + return new AggregateInfo.AggregateCall( + AggregateInfo.AggregateFunction.AVG, columnName, alias); } - /** - * Convert Flink aggregate expression to internal aggregate call - */ - private AggregateInfo.AggregateCall convertAggregateExpression( - 
AggregateExpression aggExpr, - List fieldNames, - int aggIndex) { - - FunctionDefinition funcDef = aggExpr.getFunctionDefinition(); - List args = aggExpr.getArgs(); - String alias = "agg_" + aggIndex; - - // COUNT(*) - if (funcDef == BuiltInFunctionDefinitions.COUNT) { - if (args.isEmpty()) { - // COUNT(*) - return new AggregateInfo.AggregateCall( - AggregateInfo.AggregateFunction.COUNT, null, alias); - } else { - // COUNT(column) - String columnName = args.get(0).getName(); - return new AggregateInfo.AggregateCall( - AggregateInfo.AggregateFunction.COUNT, columnName, alias); - } - } - - // SUM - if (funcDef == BuiltInFunctionDefinitions.SUM || funcDef == BuiltInFunctionDefinitions.SUM0) { - if (args.isEmpty()) { - return null; - } - String columnName = args.get(0).getName(); - return new AggregateInfo.AggregateCall( - AggregateInfo.AggregateFunction.SUM, columnName, alias); - } - - // AVG - if (funcDef == BuiltInFunctionDefinitions.AVG) { - if (args.isEmpty()) { - return null; - } - String columnName = args.get(0).getName(); - return new AggregateInfo.AggregateCall( - AggregateInfo.AggregateFunction.AVG, columnName, alias); - } - - // MIN - if (funcDef == BuiltInFunctionDefinitions.MIN) { - if (args.isEmpty()) { - return null; - } - String columnName = args.get(0).getName(); - return new AggregateInfo.AggregateCall( - AggregateInfo.AggregateFunction.MIN, columnName, alias); - } - - // MAX - if (funcDef == BuiltInFunctionDefinitions.MAX) { - if (args.isEmpty()) { - return null; - } - String columnName = args.get(0).getName(); - return new AggregateInfo.AggregateCall( - AggregateInfo.AggregateFunction.MAX, columnName, alias); - } - - // Unsupported aggregate function + // MIN + if (funcDef == BuiltInFunctionDefinitions.MIN) { + if (args.isEmpty()) { return null; + } + String columnName = args.get(0).getName(); + return new AggregateInfo.AggregateCall( + AggregateInfo.AggregateFunction.MIN, columnName, alias); } - /** - * Get aggregate info - */ - public AggregateInfo getAggregateInfo() { - return aggregateInfo; + // MAX + if (funcDef == BuiltInFunctionDefinitions.MAX) { + if (args.isEmpty()) { + return null; + } + String columnName = args.get(0).getName(); + return new AggregateInfo.AggregateCall( + AggregateInfo.AggregateFunction.MAX, columnName, alias); } - /** - * Whether aggregate push-down is accepted - */ - public boolean isAggregatePushDownAccepted() { - return aggregatePushDownAccepted; - } + // Unsupported aggregate function + return null; + } + + /** Get aggregate info */ + public AggregateInfo getAggregateInfo() { + return aggregateInfo; + } + + /** Whether aggregate push-down is accepted */ + public boolean isAggregatePushDownAccepted() { + return aggregatePushDownAccepted; + } } diff --git a/src/main/java/org/apache/flink/connector/lance/table/LanceVectorSearchFunction.java b/src/main/java/org/apache/flink/connector/lance/table/LanceVectorSearchFunction.java index 3107dce..de0c384 100644 --- a/src/main/java/org/apache/flink/connector/lance/table/LanceVectorSearchFunction.java +++ b/src/main/java/org/apache/flink/connector/lance/table/LanceVectorSearchFunction.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
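Before moving on to the vector search function below, one note on the applyAggregates implementation that closes LanceDynamicTableSource above: push-down is accepted only for a single grouping set whose calls are COUNT/SUM/AVG/MIN/MAX over plain columns; anything else makes the method return false and Flink aggregates itself. A sketch for verifying this against a build of the patch (table name and path are illustrative):

    import org.apache.flink.table.api.EnvironmentSettings;
    import org.apache.flink.table.api.TableEnvironment;

    public class LanceAggregatePlanCheck {
      public static void main(String[] args) {
        TableEnvironment tEnv =
            TableEnvironment.create(EnvironmentSettings.newInstance().inBatchMode().build());
        tEnv.executeSql(
            "CREATE TABLE lance_docs (id BIGINT, content STRING) WITH ("
                + " 'connector' = 'lance', 'path' = '/tmp/lance/docs')");

        // COUNT/MAX over a single grouping set should be folded into the source;
        // e.g. STDDEV or multiple grouping sets keep a separate aggregate operator.
        System.out.println(
            tEnv.explainSql(
                "SELECT content, COUNT(*), MAX(id) FROM lance_docs GROUP BY content"));
      }
    }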
You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.flink.connector.lance.table; import org.apache.flink.connector.lance.LanceVectorSearch; @@ -25,40 +20,34 @@ import org.apache.flink.table.api.DataTypes; import org.apache.flink.table.catalog.DataTypeFactory; import org.apache.flink.table.data.ArrayData; -import org.apache.flink.table.data.GenericArrayData; import org.apache.flink.table.data.GenericRowData; import org.apache.flink.table.data.RowData; import org.apache.flink.table.data.StringData; import org.apache.flink.table.functions.FunctionContext; import org.apache.flink.table.functions.TableFunction; -import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.inference.TypeInference; import org.apache.flink.table.types.inference.TypeStrategies; -import org.apache.flink.table.types.logical.ArrayType; -import org.apache.flink.table.types.logical.DoubleType; -import org.apache.flink.table.types.logical.RowType; import org.apache.flink.types.Row; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.IOException; import java.math.BigDecimal; import java.util.List; -import java.util.Optional; /** * Lance vector search UDF. - * + * *

<p>Implements TableFunction, supports executing vector search in SQL. - * + *
 * <p>Usage example: + *
 * <pre>{@code
  * -- Register UDF
- * CREATE TEMPORARY FUNCTION vector_search AS 
+ * CREATE TEMPORARY FUNCTION vector_search AS
  *     'org.apache.flink.connector.lance.table.LanceVectorSearchFunction'
  *     LANGUAGE JAVA USING JAR '/path/to/flink-connector-lance.jar';
- * 
+ *
  * -- Use UDF for vector search
  * SELECT * FROM TABLE(
  *     vector_search('/path/to/dataset', 'embedding', ARRAY[0.1, 0.2, 0.3], 10, 'L2')
@@ -66,292 +55,275 @@
* }</pre>
*/ @FunctionHint( - output = @DataTypeHint("ROW, _distance DOUBLE>") -) + output = + @DataTypeHint("ROW, _distance DOUBLE>")) public class LanceVectorSearchFunction extends TableFunction { - private static final long serialVersionUID = 1L; - private static final Logger LOG = LoggerFactory.getLogger(LanceVectorSearchFunction.class); + private static final long serialVersionUID = 1L; + private static final Logger LOG = LoggerFactory.getLogger(LanceVectorSearchFunction.class); + + private transient LanceVectorSearch vectorSearch; + private String currentDatasetPath; + private String currentColumnName; - private transient LanceVectorSearch vectorSearch; - private String currentDatasetPath; - private String currentColumnName; + @Override + public void open(FunctionContext context) throws Exception { + super.open(context); + LOG.info("Opening LanceVectorSearchFunction"); + } - @Override - public void open(FunctionContext context) throws Exception { - super.open(context); - LOG.info("Opening LanceVectorSearchFunction"); + @Override + public void close() throws Exception { + LOG.info("Closing LanceVectorSearchFunction"); + + if (vectorSearch != null) { + try { + vectorSearch.close(); + } catch (Exception e) { + LOG.warn("Failed to close vector searcher", e); + } + vectorSearch = null; } - @Override - public void close() throws Exception { - LOG.info("Closing LanceVectorSearchFunction"); - + super.close(); + } + + /** + * Execute vector search + * + * @param datasetPath Dataset path + * @param columnName Vector column name + * @param queryVector Query vector + * @param k Number of nearest neighbors to return + * @param metric Distance metric type: L2, Cosine, Dot + */ + public void eval( + String datasetPath, String columnName, Float[] queryVector, Integer k, String metric) { + try { + // Check if need to reinitialize vector searcher + if (vectorSearch == null + || !datasetPath.equals(currentDatasetPath) + || !columnName.equals(currentColumnName)) { + if (vectorSearch != null) { - try { - vectorSearch.close(); - } catch (Exception e) { - LOG.warn("Failed to close vector searcher", e); - } - vectorSearch = null; + vectorSearch.close(); } - - super.close(); - } - /** - * Execute vector search - * - * @param datasetPath Dataset path - * @param columnName Vector column name - * @param queryVector Query vector - * @param k Number of nearest neighbors to return - * @param metric Distance metric type: L2, Cosine, Dot - */ - public void eval(String datasetPath, String columnName, Float[] queryVector, Integer k, String metric) { - try { - // Check if need to reinitialize vector searcher - if (vectorSearch == null || - !datasetPath.equals(currentDatasetPath) || - !columnName.equals(currentColumnName)) { - - if (vectorSearch != null) { - vectorSearch.close(); - } - - LanceOptions.MetricType metricType = LanceOptions.MetricType.fromValue( - metric != null ? metric : "L2" - ); - - vectorSearch = LanceVectorSearch.builder() - .datasetPath(datasetPath) - .columnName(columnName) - .metricType(metricType) - .build(); - - vectorSearch.open(); - - currentDatasetPath = datasetPath; - currentColumnName = columnName; - } - - // Convert query vector - float[] query = new float[queryVector.length]; - for (int i = 0; i < queryVector.length; i++) { - query[i] = queryVector[i] != null ? queryVector[i] : 0.0f; - } - - // Execute search - int topK = k != null ? 
k : 10; - List results = vectorSearch.search(query, topK); - - // Output results - for (LanceVectorSearch.SearchResult result : results) { - RowData rowData = result.getRowData(); - double distance = result.getDistance(); - - // Build output Row - Row outputRow = convertToRow(rowData, distance); - if (outputRow != null) { - collect(outputRow); - } - } - - } catch (Exception e) { - LOG.error("Vector search failed", e); - throw new RuntimeException("Vector search failed: " + e.getMessage(), e); + LanceOptions.MetricType metricType = + LanceOptions.MetricType.fromValue(metric != null ? metric : "L2"); + + vectorSearch = + LanceVectorSearch.builder() + .datasetPath(datasetPath) + .columnName(columnName) + .metricType(metricType) + .build(); + + vectorSearch.open(); + + currentDatasetPath = datasetPath; + currentColumnName = columnName; + } + + // Convert query vector + float[] query = new float[queryVector.length]; + for (int i = 0; i < queryVector.length; i++) { + query[i] = queryVector[i] != null ? queryVector[i] : 0.0f; + } + + // Execute search + int topK = k != null ? k : 10; + List results = vectorSearch.search(query, topK); + + // Output results + for (LanceVectorSearch.SearchResult result : results) { + RowData rowData = result.getRowData(); + double distance = result.getDistance(); + + // Build output Row + Row outputRow = convertToRow(rowData, distance); + if (outputRow != null) { + collect(outputRow); } - } + } - /** - * Simplified vector search (using default parameters) - * - * @param datasetPath Dataset path - * @param columnName Vector column name - * @param queryVector Query vector - * @param k Number of nearest neighbors to return - */ - public void eval(String datasetPath, String columnName, Float[] queryVector, Integer k) { - eval(datasetPath, columnName, queryVector, k, "L2"); + } catch (Exception e) { + LOG.error("Vector search failed", e); + throw new RuntimeException("Vector search failed: " + e.getMessage(), e); } + } - /** - * Most simplified vector search - * - * @param datasetPath Dataset path - * @param columnName Vector column name - * @param queryVector Query vector - */ - public void eval(String datasetPath, String columnName, Float[] queryVector) { - eval(datasetPath, columnName, queryVector, 10, "L2"); - } + /** + * Simplified vector search (using default parameters) + * + * @param datasetPath Dataset path + * @param columnName Vector column name + * @param queryVector Query vector + * @param k Number of nearest neighbors to return + */ + public void eval(String datasetPath, String columnName, Float[] queryVector, Integer k) { + eval(datasetPath, columnName, queryVector, k, "L2"); + } - // ==================== BigDecimal[] parameter overloads ==================== - // ARRAY[0.1, 0.2, ...] in Flink SQL is parsed as BigDecimal[] type - - /** - * Execute vector search (supports BigDecimal[] parameter) - * - *

ARRAY[0.1, 0.2, ...] literals in Flink SQL are parsed as DECIMAL type arrays, - * so this method overload is needed for support. - * - * @param datasetPath Dataset path - * @param columnName Vector column name - * @param queryVector Query vector (BigDecimal array) - * @param k Number of nearest neighbors to return - * @param metric Distance metric type: L2, Cosine, Dot - */ - public void eval(String datasetPath, String columnName, BigDecimal[] queryVector, Integer k, String metric) { - Float[] floatVector = convertBigDecimalToFloat(queryVector); - eval(datasetPath, columnName, floatVector, k, metric); - } + /** + * Most simplified vector search + * + * @param datasetPath Dataset path + * @param columnName Vector column name + * @param queryVector Query vector + */ + public void eval(String datasetPath, String columnName, Float[] queryVector) { + eval(datasetPath, columnName, queryVector, 10, "L2"); + } - /** - * Simplified vector search (BigDecimal[] parameter) - */ - public void eval(String datasetPath, String columnName, BigDecimal[] queryVector, Integer k) { - eval(datasetPath, columnName, queryVector, k, "L2"); - } + // ==================== BigDecimal[] parameter overloads ==================== + // ARRAY[0.1, 0.2, ...] in Flink SQL is parsed as BigDecimal[] type - /** - * Most simplified vector search (BigDecimal[] parameter) - */ - public void eval(String datasetPath, String columnName, BigDecimal[] queryVector) { - eval(datasetPath, columnName, queryVector, 10, "L2"); - } + /** + * Execute vector search (supports BigDecimal[] parameter) + * + *

ARRAY[0.1, 0.2, ...] literals in Flink SQL are parsed as DECIMAL type arrays, so this method + * overload is needed for support. + * + * @param datasetPath Dataset path + * @param columnName Vector column name + * @param queryVector Query vector (BigDecimal array) + * @param k Number of nearest neighbors to return + * @param metric Distance metric type: L2, Cosine, Dot + */ + public void eval( + String datasetPath, String columnName, BigDecimal[] queryVector, Integer k, String metric) { + Float[] floatVector = convertBigDecimalToFloat(queryVector); + eval(datasetPath, columnName, floatVector, k, metric); + } - // ==================== Double[] parameter overloads ==================== - // In some cases parameters may be parsed as Double[] type + /** Simplified vector search (BigDecimal[] parameter) */ + public void eval(String datasetPath, String columnName, BigDecimal[] queryVector, Integer k) { + eval(datasetPath, columnName, queryVector, k, "L2"); + } - /** - * Execute vector search (supports Double[] parameter) - */ - public void eval(String datasetPath, String columnName, Double[] queryVector, Integer k, String metric) { - Float[] floatVector = convertDoubleToFloat(queryVector); - eval(datasetPath, columnName, floatVector, k, metric); - } + /** Most simplified vector search (BigDecimal[] parameter) */ + public void eval(String datasetPath, String columnName, BigDecimal[] queryVector) { + eval(datasetPath, columnName, queryVector, 10, "L2"); + } - /** - * Simplified vector search (Double[] parameter) - */ - public void eval(String datasetPath, String columnName, Double[] queryVector, Integer k) { - eval(datasetPath, columnName, queryVector, k, "L2"); + // ==================== Double[] parameter overloads ==================== + // In some cases parameters may be parsed as Double[] type + + /** Execute vector search (supports Double[] parameter) */ + public void eval( + String datasetPath, String columnName, Double[] queryVector, Integer k, String metric) { + Float[] floatVector = convertDoubleToFloat(queryVector); + eval(datasetPath, columnName, floatVector, k, metric); + } + + /** Simplified vector search (Double[] parameter) */ + public void eval(String datasetPath, String columnName, Double[] queryVector, Integer k) { + eval(datasetPath, columnName, queryVector, k, "L2"); + } + + /** Most simplified vector search (Double[] parameter) */ + public void eval(String datasetPath, String columnName, Double[] queryVector) { + eval(datasetPath, columnName, queryVector, 10, "L2"); + } + + // ==================== float[] primitive array parameter overloads ==================== + + /** Execute vector search (supports float[] primitive array parameter) */ + public void eval( + String datasetPath, String columnName, float[] queryVector, Integer k, String metric) { + Float[] floatVector = new Float[queryVector.length]; + for (int i = 0; i < queryVector.length; i++) { + floatVector[i] = queryVector[i]; } + eval(datasetPath, columnName, floatVector, k, metric); + } - /** - * Most simplified vector search (Double[] parameter) - */ - public void eval(String datasetPath, String columnName, Double[] queryVector) { - eval(datasetPath, columnName, queryVector, 10, "L2"); + /** Convert BigDecimal array to Float array */ + private Float[] convertBigDecimalToFloat(BigDecimal[] decimals) { + if (decimals == null) { + return new Float[0]; + } + Float[] result = new Float[decimals.length]; + for (int i = 0; i < decimals.length; i++) { + result[i] = decimals[i] != null ? 
decimals[i].floatValue() : 0.0f; } + return result; + } - // ==================== float[] primitive array parameter overloads ==================== + /** Convert Double array to Float array */ + private Float[] convertDoubleToFloat(Double[] doubles) { + if (doubles == null) { + return new Float[0]; + } + Float[] result = new Float[doubles.length]; + for (int i = 0; i < doubles.length; i++) { + result[i] = doubles[i] != null ? doubles[i].floatValue() : 0.0f; + } + return result; + } - /** - * Execute vector search (supports float[] primitive array parameter) - */ - public void eval(String datasetPath, String columnName, float[] queryVector, Integer k, String metric) { - Float[] floatVector = new Float[queryVector.length]; - for (int i = 0; i < queryVector.length; i++) { - floatVector[i] = queryVector[i]; - } - eval(datasetPath, columnName, floatVector, k, metric); + /** Convert RowData to Row */ + private Row convertToRow(RowData rowData, double distance) { + if (rowData == null) { + return null; } - /** - * Convert BigDecimal array to Float array - */ - private Float[] convertBigDecimalToFloat(BigDecimal[] decimals) { - if (decimals == null) { - return new Float[0]; - } - Float[] result = new Float[decimals.length]; - for (int i = 0; i < decimals.length; i++) { - result[i] = decimals[i] != null ? decimals[i].floatValue() : 0.0f; - } - return result; + if (rowData instanceof GenericRowData) { + GenericRowData genericRowData = (GenericRowData) rowData; + int arity = genericRowData.getArity(); + + // Create new Row including distance field + Object[] values = new Object[arity + 1]; + for (int i = 0; i < arity; i++) { + Object field = genericRowData.getField(i); + values[i] = convertField(field); + } + values[arity] = distance; + + return Row.of(values); } - /** - * Convert Double array to Float array - */ - private Float[] convertDoubleToFloat(Double[] doubles) { - if (doubles == null) { - return new Float[0]; - } - Float[] result = new Float[doubles.length]; - for (int i = 0; i < doubles.length; i++) { - result[i] = doubles[i] != null ? 
doubles[i].floatValue() : 0.0f; - } - return result; + return null; + } + + /** Convert field value */ + private Object convertField(Object field) { + if (field == null) { + return null; } - /** - * Convert RowData to Row - */ - private Row convertToRow(RowData rowData, double distance) { - if (rowData == null) { - return null; - } - - if (rowData instanceof GenericRowData) { - GenericRowData genericRowData = (GenericRowData) rowData; - int arity = genericRowData.getArity(); - - // Create new Row including distance field - Object[] values = new Object[arity + 1]; - for (int i = 0; i < arity; i++) { - Object field = genericRowData.getField(i); - values[i] = convertField(field); - } - values[arity] = distance; - - return Row.of(values); - } - - return null; + if (field instanceof StringData) { + return ((StringData) field).toString(); } - /** - * Convert field value - */ - private Object convertField(Object field) { - if (field == null) { - return null; + if (field instanceof ArrayData) { + ArrayData arrayData = (ArrayData) field; + int size = arrayData.size(); + Float[] result = new Float[size]; + for (int i = 0; i < size; i++) { + if (arrayData.isNullAt(i)) { + result[i] = null; + } else { + result[i] = arrayData.getFloat(i); } - - if (field instanceof StringData) { - return ((StringData) field).toString(); - } - - if (field instanceof ArrayData) { - ArrayData arrayData = (ArrayData) field; - int size = arrayData.size(); - Float[] result = new Float[size]; - for (int i = 0; i < size; i++) { - if (arrayData.isNullAt(i)) { - result[i] = null; - } else { - result[i] = arrayData.getFloat(i); - } - } - return result; - } - - return field; + } + return result; } - @Override - public TypeInference getTypeInference(DataTypeFactory typeFactory) { - return TypeInference.newBuilder() - .outputTypeStrategy(TypeStrategies.explicit( - DataTypes.ROW( - DataTypes.FIELD("id", DataTypes.BIGINT()), - DataTypes.FIELD("content", DataTypes.STRING()), - DataTypes.FIELD("embedding", DataTypes.ARRAY(DataTypes.FLOAT())), - DataTypes.FIELD("_distance", DataTypes.DOUBLE()) - ) - )) - .build(); - } + return field; + } + + @Override + public TypeInference getTypeInference(DataTypeFactory typeFactory) { + return TypeInference.newBuilder() + .outputTypeStrategy( + TypeStrategies.explicit( + DataTypes.ROW( + DataTypes.FIELD("id", DataTypes.BIGINT()), + DataTypes.FIELD("content", DataTypes.STRING()), + DataTypes.FIELD("embedding", DataTypes.ARRAY(DataTypes.FLOAT())), + DataTypes.FIELD("_distance", DataTypes.DOUBLE())))) + .build(); + } } diff --git a/src/test/java/org/apache/flink/connector/lance/LanceConnectorITCase.java b/src/test/java/org/apache/flink/connector/lance/LanceConnectorITCase.java index 9967ab3..221b0cd 100644 --- a/src/test/java/org/apache/flink/connector/lance/LanceConnectorITCase.java +++ b/src/test/java/org/apache/flink/connector/lance/LanceConnectorITCase.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
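To close out the LanceVectorSearchFunction diff above: the many eval overloads exist because Flink SQL hands ARRAY literals to the function as BigDecimal[] rather than Float[]. A hedged usage sketch (the dataset path is hypothetical; registration mirrors the CREATE TEMPORARY FUNCTION DDL in the class javadoc):

    import org.apache.flink.connector.lance.table.LanceVectorSearchFunction;
    import org.apache.flink.table.api.EnvironmentSettings;
    import org.apache.flink.table.api.TableEnvironment;

    public class LanceVectorSearchExample {
      public static void main(String[] args) {
        TableEnvironment tEnv =
            TableEnvironment.create(EnvironmentSettings.newInstance().inBatchMode().build());

        // Programmatic registration, equivalent to the DDL form shown in the javadoc.
        tEnv.createTemporarySystemFunction("vector_search", LanceVectorSearchFunction.class);

        // ARRAY[0.1, 0.2, 0.3] reaches the function as BigDecimal[], which is why
        // the class keeps BigDecimal[] overloads next to Float[]/Double[]/float[].
        tEnv.executeSql(
                "SELECT id, content, _distance FROM TABLE("
                    + "vector_search('/tmp/lance/docs', 'embedding',"
                    + " ARRAY[0.1, 0.2, 0.3], 5, 'Cosine'))")
            .print();
      }
    }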
+ * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
@@ -15,7 +11,6 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
-
 package org.apache.flink.connector.lance;
 
 import org.apache.flink.connector.lance.config.LanceOptions;
@@ -52,364 +47,353 @@
 import static org.assertj.core.api.Assertions.assertThat;
 
-/**
- * Lance Connector end-to-end integration tests.
- */
+/** Lance Connector end-to-end integration tests. */
 class LanceConnectorITCase {
 
- @TempDir
- Path tempDir;
-
- private String datasetPath;
- private String warehousePath;
- private RowType rowType;
- private DataType dataType;
-
- @BeforeEach
- void setUp() {
- datasetPath = tempDir.resolve("test_e2e_dataset").toString();
- warehousePath = tempDir.resolve("test_e2e_warehouse").toString();
-
- // Create test Schema
- List<RowType.RowField> fields = new ArrayList<>();
- fields.add(new RowType.RowField("id", new BigIntType()));
- fields.add(new RowType.RowField("content", new VarCharType()));
- fields.add(new RowType.RowField("embedding", new ArrayType(new FloatType())));
- rowType = new RowType(fields);
-
- dataType = DataTypes.ROW(
- DataTypes.FIELD("id", DataTypes.BIGINT()),
- DataTypes.FIELD("content", DataTypes.STRING()),
- DataTypes.FIELD("embedding", DataTypes.ARRAY(DataTypes.FLOAT()))
- );
- }
-
- @Test
- @DisplayName("Test complete configuration options workflow")
- void testCompleteOptionsWorkflow() {
- // Build complete configuration
- LanceOptions options = LanceOptions.builder()
- .path(datasetPath)
- // Source configuration
- .readBatchSize(512)
- .readColumns(Arrays.asList("id", "content", "embedding"))
- .readFilter("id > 0")
- // Sink configuration
- .writeBatchSize(256)
- .writeMode(WriteMode.APPEND)
- .writeMaxRowsPerFile(100000)
- // Index configuration
- .indexType(IndexType.IVF_PQ)
- .indexColumn("embedding")
- .indexNumPartitions(128)
- .indexNumSubVectors(16)
- .indexNumBits(8)
- // Vector search configuration
- .vectorColumn("embedding")
- .vectorMetric(MetricType.L2)
- .vectorNprobes(20)
- .vectorEf(100)
- // Catalog configuration
- .defaultDatabase("default")
- .warehouse(warehousePath)
- .build();
-
- // Verify all configurations
- assertThat(options.getPath()).isEqualTo(datasetPath);
- assertThat(options.getReadBatchSize()).isEqualTo(512);
- assertThat(options.getReadColumns()).containsExactly("id", "content", "embedding");
- assertThat(options.getReadFilter()).isEqualTo("id > 0");
- assertThat(options.getWriteBatchSize()).isEqualTo(256);
- assertThat(options.getWriteMode()).isEqualTo(WriteMode.APPEND);
- assertThat(options.getWriteMaxRowsPerFile()).isEqualTo(100000);
- assertThat(options.getIndexType()).isEqualTo(IndexType.IVF_PQ);
- assertThat(options.getIndexColumn()).isEqualTo("embedding");
- assertThat(options.getIndexNumPartitions()).isEqualTo(128);
- assertThat(options.getIndexNumSubVectors()).isEqualTo(16);
- assertThat(options.getIndexNumBits()).isEqualTo(8);
- assertThat(options.getVectorColumn()).isEqualTo("embedding");
- assertThat(options.getVectorMetric()).isEqualTo(MetricType.L2);
- assertThat(options.getVectorNprobes()).isEqualTo(20);
- assertThat(options.getVectorEf()).isEqualTo(100);
- assertThat(options.getDefaultDatabase()).isEqualTo("default");
- assertThat(options.getWarehouse()).isEqualTo(warehousePath);
- }
-
- @Test
- @DisplayName("Test RowDataConverter data conversion workflow")
- void testRowDataConverterWorkflow() {
- RowDataConverter converter = new RowDataConverter(rowType);
-
- // Create test
data - List testData = new ArrayList<>(); - for (int i = 0; i < 10; i++) { - GenericRowData row = new GenericRowData(3); - row.setField(0, (long) i); - row.setField(1, StringData.fromString("Content " + i)); - - // Create vector data - Float[] vector = new Float[128]; - for (int j = 0; j < 128; j++) { - vector[j] = (float) (i * 0.1 + j * 0.01); - } - row.setField(2, new GenericArrayData(vector)); - - testData.add(row); - } - - // Verify data creation succeeded - assertThat(testData).hasSize(10); - assertThat(converter.getRowType()).isEqualTo(rowType); - assertThat(converter.getFieldNames()).containsExactly("id", "content", "embedding"); - } - - @Test - @DisplayName("Test LanceSource builder pattern") - void testLanceSourceBuilder() { - LanceSource source = LanceSource.builder() - .path(datasetPath) - .batchSize(256) - .columns(Arrays.asList("id", "embedding")) - .filter("id < 1000") - .rowType(rowType) - .build(); - - assertThat(source.getOptions().getPath()).isEqualTo(datasetPath); - assertThat(source.getOptions().getReadBatchSize()).isEqualTo(256); - assertThat(source.getSelectedColumns()).containsExactly("id", "embedding"); - assertThat(source.getRowType()).isEqualTo(rowType); - } - - @Test - @DisplayName("Test LanceSink builder pattern") - void testLanceSinkBuilder() { - LanceSink sink = LanceSink.builder() - .path(datasetPath) - .batchSize(128) - .writeMode(WriteMode.OVERWRITE) - .maxRowsPerFile(50000) - .rowType(rowType) - .build(); - - assertThat(sink.getOptions().getPath()).isEqualTo(datasetPath); - assertThat(sink.getOptions().getWriteBatchSize()).isEqualTo(128); - assertThat(sink.getOptions().getWriteMode()).isEqualTo(WriteMode.OVERWRITE); - assertThat(sink.getOptions().getWriteMaxRowsPerFile()).isEqualTo(50000); - assertThat(sink.getRowType()).isEqualTo(rowType); - } - - @Test - @DisplayName("Test LanceIndexBuilder builder pattern") - void testLanceIndexBuilder() { - LanceIndexBuilder builder = LanceIndexBuilder.builder() - .datasetPath(datasetPath) - .columnName("embedding") - .indexType(IndexType.IVF_HNSW) - .metricType(MetricType.COSINE) - .numPartitions(64) - .maxLevel(5) - .m(24) - .efConstruction(200) - .replace(true) - .build(); - - assertThat(builder).isNotNull(); - } - - @Test - @DisplayName("Test LanceVectorSearch builder pattern") - void testLanceVectorSearchBuilder() { - LanceVectorSearch search = LanceVectorSearch.builder() - .datasetPath(datasetPath) - .columnName("embedding") - .metricType(MetricType.DOT) - .nprobes(30) - .ef(150) - .refineFactor(5) - .build(); - - assertThat(search).isNotNull(); - } - - @Test - @DisplayName("Test Table API component creation") - void testTableApiComponents() { - LanceOptions options = LanceOptions.builder() - .path(datasetPath) - .build(); - - // Create DynamicTableSource - LanceDynamicTableSource source = new LanceDynamicTableSource(options, dataType); - assertThat(source.asSummaryString()).isEqualTo("Lance Table Source"); - - // Create DynamicTableSink - LanceDynamicTableSink sink = new LanceDynamicTableSink(options, dataType); - assertThat(sink.asSummaryString()).isEqualTo("Lance Table Sink"); - - // Create Factory - LanceDynamicTableFactory factory = new LanceDynamicTableFactory(); - assertThat(factory.factoryIdentifier()).isEqualTo("lance"); - } - - @Test - @DisplayName("Test Catalog lifecycle") - void testCatalogLifecycle() throws Exception { - LanceCatalog catalog = new LanceCatalog("test_catalog", "default", warehousePath); - - // Open Catalog - catalog.open(); - 
assertThat(catalog.getDefaultDatabase()).isEqualTo("default");
- assertThat(catalog.getWarehouse()).isEqualTo(warehousePath);
-
- // Verify default database exists
- assertThat(catalog.databaseExists("default")).isTrue();
-
- // Create test database
- catalog.createDatabase("test_db", null, true);
- assertThat(catalog.databaseExists("test_db")).isTrue();
- assertThat(catalog.listDatabases()).contains("default", "test_db");
-
- // List empty tables
- assertThat(catalog.listTables("test_db")).isEmpty();
-
- // Drop test database
- catalog.dropDatabase("test_db", true, true);
- assertThat(catalog.databaseExists("test_db")).isFalse();
-
- // Close Catalog
- catalog.close();
- }
-
- @Test
- @DisplayName("Test type conversion bidirectional consistency")
- void testTypeConversionConsistency() {
- // Flink RowType -> Arrow Schema -> Flink RowType
- org.apache.arrow.vector.types.pojo.Schema arrowSchema =
- LanceTypeConverter.toArrowSchema(rowType);
- RowType convertedRowType = LanceTypeConverter.toFlinkRowType(arrowSchema);
-
- // Verify field count
- assertThat(convertedRowType.getFieldCount()).isEqualTo(rowType.getFieldCount());
-
- // Verify field names
- assertThat(convertedRowType.getFieldNames()).isEqualTo(rowType.getFieldNames());
- }
-
- @Test
- @DisplayName("Test vector data conversion")
- void testVectorDataConversion() {
- // Create float array
- float[] originalVector = new float[] {0.1f, 0.2f, 0.3f, 0.4f, 0.5f};
-
- // Convert to ArrayData
- org.apache.flink.table.data.ArrayData arrayData =
- RowDataConverter.toArrayData(originalVector);
-
- // Convert back to float array
- float[] convertedVector = RowDataConverter.toFloatArray(arrayData);
-
- // Verify consistency
- assertThat(convertedVector).containsExactly(originalVector);
+ @TempDir Path tempDir;
+
+ private String datasetPath;
+ private String warehousePath;
+ private RowType rowType;
+ private DataType dataType;
+
+ @BeforeEach
+ void setUp() {
+ datasetPath = tempDir.resolve("test_e2e_dataset").toString();
+ warehousePath = tempDir.resolve("test_e2e_warehouse").toString();
+
+ // Create test Schema
+ List<RowType.RowField> fields = new ArrayList<>();
+ fields.add(new RowType.RowField("id", new BigIntType()));
+ fields.add(new RowType.RowField("content", new VarCharType()));
+ fields.add(new RowType.RowField("embedding", new ArrayType(new FloatType())));
+ rowType = new RowType(fields);
+
+ dataType =
+ DataTypes.ROW(
+ DataTypes.FIELD("id", DataTypes.BIGINT()),
+ DataTypes.FIELD("content", DataTypes.STRING()),
+ DataTypes.FIELD("embedding", DataTypes.ARRAY(DataTypes.FLOAT())));
+ }
+
+ @Test
+ @DisplayName("Test complete configuration options workflow")
+ void testCompleteOptionsWorkflow() {
+ // Build complete configuration
+ LanceOptions options =
+ LanceOptions.builder()
+ .path(datasetPath)
+ // Source configuration
+ .readBatchSize(512)
+ .readColumns(Arrays.asList("id", "content", "embedding"))
+ .readFilter("id > 0")
+ // Sink configuration
+ .writeBatchSize(256)
+ .writeMode(WriteMode.APPEND)
+ .writeMaxRowsPerFile(100000)
+ // Index configuration
+ .indexType(IndexType.IVF_PQ)
+ .indexColumn("embedding")
+ .indexNumPartitions(128)
+ .indexNumSubVectors(16)
+ .indexNumBits(8)
+ // Vector search configuration
+ .vectorColumn("embedding")
+ .vectorMetric(MetricType.L2)
+ .vectorNprobes(20)
+ .vectorEf(100)
+ // Catalog configuration
+ .defaultDatabase("default")
+ .warehouse(warehousePath)
+ .build();
+
+ // Verify all configurations
+ assertThat(options.getPath()).isEqualTo(datasetPath);
+
assertThat(options.getReadBatchSize()).isEqualTo(512); + assertThat(options.getReadColumns()).containsExactly("id", "content", "embedding"); + assertThat(options.getReadFilter()).isEqualTo("id > 0"); + assertThat(options.getWriteBatchSize()).isEqualTo(256); + assertThat(options.getWriteMode()).isEqualTo(WriteMode.APPEND); + assertThat(options.getWriteMaxRowsPerFile()).isEqualTo(100000); + assertThat(options.getIndexType()).isEqualTo(IndexType.IVF_PQ); + assertThat(options.getIndexColumn()).isEqualTo("embedding"); + assertThat(options.getIndexNumPartitions()).isEqualTo(128); + assertThat(options.getIndexNumSubVectors()).isEqualTo(16); + assertThat(options.getIndexNumBits()).isEqualTo(8); + assertThat(options.getVectorColumn()).isEqualTo("embedding"); + assertThat(options.getVectorMetric()).isEqualTo(MetricType.L2); + assertThat(options.getVectorNprobes()).isEqualTo(20); + assertThat(options.getVectorEf()).isEqualTo(100); + assertThat(options.getDefaultDatabase()).isEqualTo("default"); + assertThat(options.getWarehouse()).isEqualTo(warehousePath); + } + + @Test + @DisplayName("Test RowDataConverter data conversion workflow") + void testRowDataConverterWorkflow() { + RowDataConverter converter = new RowDataConverter(rowType); + + // Create test data + List testData = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + GenericRowData row = new GenericRowData(3); + row.setField(0, (long) i); + row.setField(1, StringData.fromString("Content " + i)); + + // Create vector data + Float[] vector = new Float[128]; + for (int j = 0; j < 128; j++) { + vector[j] = (float) (i * 0.1 + j * 0.01); + } + row.setField(2, new GenericArrayData(vector)); + + testData.add(row); } - @Test - @DisplayName("Test double vector data conversion") - void testDoubleVectorDataConversion() { - // Create double array - double[] originalVector = new double[] {0.1, 0.2, 0.3, 0.4, 0.5}; - - // Convert to ArrayData - org.apache.flink.table.data.ArrayData arrayData = - RowDataConverter.toArrayData(originalVector); - - // Convert back to double array - double[] convertedVector = RowDataConverter.toDoubleArray(arrayData); - - // Verify consistency - assertThat(convertedVector).containsExactly(originalVector); - } - - @Test - @DisplayName("Test LanceSplit serialization compatibility") - void testLanceSplitSerialization() { - LanceSplit split1 = new LanceSplit(0, 1, datasetPath, 10000); - LanceSplit split2 = new LanceSplit(0, 1, datasetPath, 10000); - LanceSplit split3 = new LanceSplit(1, 2, datasetPath, 20000); - - // Equality test - assertThat(split1).isEqualTo(split2); - assertThat(split1.hashCode()).isEqualTo(split2.hashCode()); - assertThat(split1).isNotEqualTo(split3); - - // toString test - String str = split1.toString(); - assertThat(str).contains("LanceSplit"); - assertThat(str).contains("fragmentId=1"); - assertThat(str).contains("rowCount=10000"); - } - - @Test - @DisplayName("Test search result similarity calculation") - void testSearchResultSimilarityCalculation() { - // Perfect match (distance=0) - LanceVectorSearch.SearchResult perfectMatch = - new LanceVectorSearch.SearchResult(null, 0.0); - assertThat(perfectMatch.getSimilarity()).isEqualTo(1.0); - - // Normal match (distance=1) - LanceVectorSearch.SearchResult normalMatch = - new LanceVectorSearch.SearchResult(null, 1.0); - assertThat(normalMatch.getSimilarity()).isEqualTo(0.5); - - // Far match (distance=9) - LanceVectorSearch.SearchResult farMatch = - new LanceVectorSearch.SearchResult(null, 9.0); - assertThat(farMatch.getSimilarity()).isEqualTo(0.1); - } - - 
@Test - @DisplayName("Test options toString and hashCode") - void testOptionsToStringAndHashCode() { - LanceOptions options1 = LanceOptions.builder() - .path(datasetPath) - .readBatchSize(512) - .build(); - - LanceOptions options2 = LanceOptions.builder() - .path(datasetPath) - .readBatchSize(512) - .build(); - - // hashCode equals - assertThat(options1.hashCode()).isEqualTo(options2.hashCode()); - - // equals - assertThat(options1).isEqualTo(options2); - - // toString contains key info - String str = options1.toString(); - assertThat(str).contains("LanceOptions"); - assertThat(str).contains("readBatchSize=512"); - } - - @Test - @DisplayName("Test all enum types") - void testAllEnumTypes() { - // WriteMode - assertThat(WriteMode.values()).hasSize(2); - assertThat(WriteMode.APPEND.getValue()).isEqualTo("append"); - assertThat(WriteMode.OVERWRITE.getValue()).isEqualTo("overwrite"); - - // IndexType - assertThat(IndexType.values()).hasSize(3); - assertThat(IndexType.IVF_PQ.getValue()).isEqualTo("IVF_PQ"); - assertThat(IndexType.IVF_HNSW.getValue()).isEqualTo("IVF_HNSW"); - assertThat(IndexType.IVF_FLAT.getValue()).isEqualTo("IVF_FLAT"); - - // MetricType - assertThat(MetricType.values()).hasSize(3); - assertThat(MetricType.L2.getValue()).isEqualTo("L2"); - assertThat(MetricType.COSINE.getValue()).isEqualTo("Cosine"); - assertThat(MetricType.DOT.getValue()).isEqualTo("Dot"); - } + // Verify data creation succeeded + assertThat(testData).hasSize(10); + assertThat(converter.getRowType()).isEqualTo(rowType); + assertThat(converter.getFieldNames()).containsExactly("id", "content", "embedding"); + } + + @Test + @DisplayName("Test LanceSource builder pattern") + void testLanceSourceBuilder() { + LanceSource source = + LanceSource.builder() + .path(datasetPath) + .batchSize(256) + .columns(Arrays.asList("id", "embedding")) + .filter("id < 1000") + .rowType(rowType) + .build(); + + assertThat(source.getOptions().getPath()).isEqualTo(datasetPath); + assertThat(source.getOptions().getReadBatchSize()).isEqualTo(256); + assertThat(source.getSelectedColumns()).containsExactly("id", "embedding"); + assertThat(source.getRowType()).isEqualTo(rowType); + } + + @Test + @DisplayName("Test LanceSink builder pattern") + void testLanceSinkBuilder() { + LanceSink sink = + LanceSink.builder() + .path(datasetPath) + .batchSize(128) + .writeMode(WriteMode.OVERWRITE) + .maxRowsPerFile(50000) + .rowType(rowType) + .build(); + + assertThat(sink.getOptions().getPath()).isEqualTo(datasetPath); + assertThat(sink.getOptions().getWriteBatchSize()).isEqualTo(128); + assertThat(sink.getOptions().getWriteMode()).isEqualTo(WriteMode.OVERWRITE); + assertThat(sink.getOptions().getWriteMaxRowsPerFile()).isEqualTo(50000); + assertThat(sink.getRowType()).isEqualTo(rowType); + } + + @Test + @DisplayName("Test LanceIndexBuilder builder pattern") + void testLanceIndexBuilder() { + LanceIndexBuilder builder = + LanceIndexBuilder.builder() + .datasetPath(datasetPath) + .columnName("embedding") + .indexType(IndexType.IVF_HNSW) + .metricType(MetricType.COSINE) + .numPartitions(64) + .maxLevel(5) + .m(24) + .efConstruction(200) + .replace(true) + .build(); + + assertThat(builder).isNotNull(); + } + + @Test + @DisplayName("Test LanceVectorSearch builder pattern") + void testLanceVectorSearchBuilder() { + LanceVectorSearch search = + LanceVectorSearch.builder() + .datasetPath(datasetPath) + .columnName("embedding") + .metricType(MetricType.DOT) + .nprobes(30) + .ef(150) + .refineFactor(5) + .build(); + + assertThat(search).isNotNull(); + } + + 
@Test + @DisplayName("Test Table API component creation") + void testTableApiComponents() { + LanceOptions options = LanceOptions.builder().path(datasetPath).build(); + + // Create DynamicTableSource + LanceDynamicTableSource source = new LanceDynamicTableSource(options, dataType); + assertThat(source.asSummaryString()).isEqualTo("Lance Table Source"); + + // Create DynamicTableSink + LanceDynamicTableSink sink = new LanceDynamicTableSink(options, dataType); + assertThat(sink.asSummaryString()).isEqualTo("Lance Table Sink"); + + // Create Factory + LanceDynamicTableFactory factory = new LanceDynamicTableFactory(); + assertThat(factory.factoryIdentifier()).isEqualTo("lance"); + } + + @Test + @DisplayName("Test Catalog lifecycle") + void testCatalogLifecycle() throws Exception { + LanceCatalog catalog = new LanceCatalog("test_catalog", "default", warehousePath); + + // Open Catalog + catalog.open(); + assertThat(catalog.getDefaultDatabase()).isEqualTo("default"); + assertThat(catalog.getWarehouse()).isEqualTo(warehousePath); + + // Verify default database exists + assertThat(catalog.databaseExists("default")).isTrue(); + + // Create test database + catalog.createDatabase("test_db", null, true); + assertThat(catalog.databaseExists("test_db")).isTrue(); + assertThat(catalog.listDatabases()).contains("default", "test_db"); + + // List empty tables + assertThat(catalog.listTables("test_db")).isEmpty(); + + // Drop test database + catalog.dropDatabase("test_db", true, true); + assertThat(catalog.databaseExists("test_db")).isFalse(); + + // Close Catalog + catalog.close(); + } + + @Test + @DisplayName("Test type conversion bidirectional consistency") + void testTypeConversionConsistency() { + // Flink RowType -> Arrow Schema -> Flink RowType + org.apache.arrow.vector.types.pojo.Schema arrowSchema = + LanceTypeConverter.toArrowSchema(rowType); + RowType convertedRowType = LanceTypeConverter.toFlinkRowType(arrowSchema); + + // Verify field count + assertThat(convertedRowType.getFieldCount()).isEqualTo(rowType.getFieldCount()); + + // Verify field names + assertThat(convertedRowType.getFieldNames()).isEqualTo(rowType.getFieldNames()); + } + + @Test + @DisplayName("Test vector data conversion") + void testVectorDataConversion() { + // Create float array + float[] originalVector = new float[] {0.1f, 0.2f, 0.3f, 0.4f, 0.5f}; + + // Convert to ArrayData + org.apache.flink.table.data.ArrayData arrayData = RowDataConverter.toArrayData(originalVector); + + // Convert back to float array + float[] convertedVector = RowDataConverter.toFloatArray(arrayData); + + // Verify consistency + assertThat(convertedVector).containsExactly(originalVector); + } + + @Test + @DisplayName("Test double vector data conversion") + void testDoubleVectorDataConversion() { + // Create double array + double[] originalVector = new double[] {0.1, 0.2, 0.3, 0.4, 0.5}; + + // Convert to ArrayData + org.apache.flink.table.data.ArrayData arrayData = RowDataConverter.toArrayData(originalVector); + + // Convert back to double array + double[] convertedVector = RowDataConverter.toDoubleArray(arrayData); + + // Verify consistency + assertThat(convertedVector).containsExactly(originalVector); + } + + @Test + @DisplayName("Test LanceSplit serialization compatibility") + void testLanceSplitSerialization() { + LanceSplit split1 = new LanceSplit(0, 1, datasetPath, 10000); + LanceSplit split2 = new LanceSplit(0, 1, datasetPath, 10000); + LanceSplit split3 = new LanceSplit(1, 2, datasetPath, 20000); + + // Equality test + 
assertThat(split1).isEqualTo(split2); + assertThat(split1.hashCode()).isEqualTo(split2.hashCode()); + assertThat(split1).isNotEqualTo(split3); + + // toString test + String str = split1.toString(); + assertThat(str).contains("LanceSplit"); + assertThat(str).contains("fragmentId=1"); + assertThat(str).contains("rowCount=10000"); + } + + @Test + @DisplayName("Test search result similarity calculation") + void testSearchResultSimilarityCalculation() { + // Perfect match (distance=0) + LanceVectorSearch.SearchResult perfectMatch = new LanceVectorSearch.SearchResult(null, 0.0); + assertThat(perfectMatch.getSimilarity()).isEqualTo(1.0); + + // Normal match (distance=1) + LanceVectorSearch.SearchResult normalMatch = new LanceVectorSearch.SearchResult(null, 1.0); + assertThat(normalMatch.getSimilarity()).isEqualTo(0.5); + + // Far match (distance=9) + LanceVectorSearch.SearchResult farMatch = new LanceVectorSearch.SearchResult(null, 9.0); + assertThat(farMatch.getSimilarity()).isEqualTo(0.1); + } + + @Test + @DisplayName("Test options toString and hashCode") + void testOptionsToStringAndHashCode() { + LanceOptions options1 = LanceOptions.builder().path(datasetPath).readBatchSize(512).build(); + + LanceOptions options2 = LanceOptions.builder().path(datasetPath).readBatchSize(512).build(); + + // hashCode equals + assertThat(options1.hashCode()).isEqualTo(options2.hashCode()); + + // equals + assertThat(options1).isEqualTo(options2); + + // toString contains key info + String str = options1.toString(); + assertThat(str).contains("LanceOptions"); + assertThat(str).contains("readBatchSize=512"); + } + + @Test + @DisplayName("Test all enum types") + void testAllEnumTypes() { + // WriteMode + assertThat(WriteMode.values()).hasSize(2); + assertThat(WriteMode.APPEND.getValue()).isEqualTo("append"); + assertThat(WriteMode.OVERWRITE.getValue()).isEqualTo("overwrite"); + + // IndexType + assertThat(IndexType.values()).hasSize(3); + assertThat(IndexType.IVF_PQ.getValue()).isEqualTo("IVF_PQ"); + assertThat(IndexType.IVF_HNSW.getValue()).isEqualTo("IVF_HNSW"); + assertThat(IndexType.IVF_FLAT.getValue()).isEqualTo("IVF_FLAT"); + + // MetricType + assertThat(MetricType.values()).hasSize(3); + assertThat(MetricType.L2.getValue()).isEqualTo("L2"); + assertThat(MetricType.COSINE.getValue()).isEqualTo("Cosine"); + assertThat(MetricType.DOT.getValue()).isEqualTo("Dot"); + } } diff --git a/src/test/java/org/apache/flink/connector/lance/LanceIndexBuilderTest.java b/src/test/java/org/apache/flink/connector/lance/LanceIndexBuilderTest.java index 81972d2..4df842d 100644 --- a/src/test/java/org/apache/flink/connector/lance/LanceIndexBuilderTest.java +++ b/src/test/java/org/apache/flink/connector/lance/LanceIndexBuilderTest.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ - package org.apache.flink.connector.lance; import org.apache.flink.connector.lance.config.LanceOptions; @@ -32,258 +27,260 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -/** - * LanceIndexBuilder unit tests. - */ +/** LanceIndexBuilder unit tests. */ class LanceIndexBuilderTest { - @TempDir - Path tempDir; - - private String datasetPath; - - @BeforeEach - void setUp() { - datasetPath = tempDir.resolve("test_index_dataset").toString(); - } - - @Test - @DisplayName("Test IVF_PQ index configuration build") - void testIvfPqIndexConfiguration() { - LanceIndexBuilder builder = LanceIndexBuilder.builder() - .datasetPath(datasetPath) - .columnName("embedding") - .indexType(IndexType.IVF_PQ) - .numPartitions(128) - .numSubVectors(16) - .numBits(8) - .metricType(MetricType.L2) - .build(); - - // Verify configuration - by successful build - assertThat(builder).isNotNull(); - } - - @Test - @DisplayName("Test IVF_HNSW index configuration build") - void testIvfHnswIndexConfiguration() { - LanceIndexBuilder builder = LanceIndexBuilder.builder() - .datasetPath(datasetPath) - .columnName("embedding") - .indexType(IndexType.IVF_HNSW) - .numPartitions(64) - .maxLevel(5) - .m(24) - .efConstruction(200) - .metricType(MetricType.COSINE) - .build(); - - assertThat(builder).isNotNull(); - } - - @Test - @DisplayName("Test IVF_FLAT index configuration build") - void testIvfFlatIndexConfiguration() { - LanceIndexBuilder builder = LanceIndexBuilder.builder() - .datasetPath(datasetPath) - .columnName("embedding") - .indexType(IndexType.IVF_FLAT) - .numPartitions(256) - .metricType(MetricType.DOT) - .build(); - - assertThat(builder).isNotNull(); - } - - @Test - @DisplayName("Test index type enum") - void testIndexTypeEnum() { - assertThat(IndexType.fromValue("IVF_PQ")).isEqualTo(IndexType.IVF_PQ); - assertThat(IndexType.fromValue("ivf_pq")).isEqualTo(IndexType.IVF_PQ); - assertThat(IndexType.fromValue("IVF_HNSW")).isEqualTo(IndexType.IVF_HNSW); - assertThat(IndexType.fromValue("IVF_FLAT")).isEqualTo(IndexType.IVF_FLAT); - - assertThat(IndexType.IVF_PQ.getValue()).isEqualTo("IVF_PQ"); - assertThat(IndexType.IVF_HNSW.getValue()).isEqualTo("IVF_HNSW"); - assertThat(IndexType.IVF_FLAT.getValue()).isEqualTo("IVF_FLAT"); - } - - @Test - @DisplayName("Test invalid index type") - void testInvalidIndexType() { - assertThatThrownBy(() -> IndexType.fromValue("INVALID")) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Unsupported index type"); - } - - @Test - @DisplayName("Test metric type enum") - void testMetricTypeEnum() { - assertThat(MetricType.fromValue("L2")).isEqualTo(MetricType.L2); - assertThat(MetricType.fromValue("l2")).isEqualTo(MetricType.L2); - assertThat(MetricType.fromValue("Cosine")).isEqualTo(MetricType.COSINE); - assertThat(MetricType.fromValue("cosine")).isEqualTo(MetricType.COSINE); - assertThat(MetricType.fromValue("Dot")).isEqualTo(MetricType.DOT); - assertThat(MetricType.fromValue("dot")).isEqualTo(MetricType.DOT); - - assertThat(MetricType.L2.getValue()).isEqualTo("L2"); - assertThat(MetricType.COSINE.getValue()).isEqualTo("Cosine"); - assertThat(MetricType.DOT.getValue()).isEqualTo("Dot"); - } - - @Test - @DisplayName("Test invalid metric type") - void testInvalidMetricType() { - assertThatThrownBy(() -> MetricType.fromValue("INVALID")) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Unsupported metric type"); - } - - @Test - @DisplayName("Test exception when missing 
dataset path") - void testMissingDatasetPath() { - assertThatThrownBy(() -> LanceIndexBuilder.builder() - .columnName("embedding") - .indexType(IndexType.IVF_PQ) - .build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Dataset path cannot be empty"); - } - - @Test - @DisplayName("Test exception when missing column name") - void testMissingColumnName() { - assertThatThrownBy(() -> LanceIndexBuilder.builder() - .datasetPath(datasetPath) - .indexType(IndexType.IVF_PQ) - .build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Column name cannot be empty"); - } - - @Test - @DisplayName("Test invalid number of partitions") - void testInvalidNumPartitions() { - assertThatThrownBy(() -> LanceIndexBuilder.builder() - .datasetPath(datasetPath) - .columnName("embedding") - .numPartitions(0) - .build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Number of partitions must be greater than 0"); - } - - @Test - @DisplayName("Test invalid number of sub-vectors") - void testInvalidNumSubVectors() { - assertThatThrownBy(() -> LanceIndexBuilder.builder() - .datasetPath(datasetPath) - .columnName("embedding") - .numSubVectors(-1) - .build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Number of sub-vectors must be greater than 0"); - } - - @Test - @DisplayName("Test invalid number of quantization bits") - void testInvalidNumBits() { - assertThatThrownBy(() -> LanceIndexBuilder.builder() - .datasetPath(datasetPath) - .columnName("embedding") - .numBits(0) - .build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Quantization bits must be between 1 and 16"); - - assertThatThrownBy(() -> LanceIndexBuilder.builder() - .datasetPath(datasetPath) - .columnName("embedding") - .numBits(17) - .build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Quantization bits must be between 1 and 16"); - } - - @Test - @DisplayName("Test default index configuration values") - void testDefaultIndexConfiguration() { - LanceOptions options = LanceOptions.builder() - .path(datasetPath) - .indexColumn("embedding") - .build(); - - // Verify default values - assertThat(options.getIndexType()).isEqualTo(IndexType.IVF_PQ); - assertThat(options.getIndexNumPartitions()).isEqualTo(256); - assertThat(options.getIndexNumBits()).isEqualTo(8); - assertThat(options.getIndexMaxLevel()).isEqualTo(7); - assertThat(options.getIndexM()).isEqualTo(16); - assertThat(options.getIndexEfConstruction()).isEqualTo(100); - } - - @Test - @DisplayName("Test creating index builder from LanceOptions") - void testFromOptions() { - LanceOptions options = LanceOptions.builder() - .path(datasetPath) - .indexColumn("embedding") - .indexType(IndexType.IVF_HNSW) - .indexNumPartitions(64) - .vectorMetric(MetricType.COSINE) - .build(); - - LanceIndexBuilder builder = LanceIndexBuilder.fromOptions(options); - - assertThat(builder).isNotNull(); - } - - @Test - @DisplayName("Test index build result") - void testIndexBuildResult() { - LanceIndexBuilder.IndexBuildResult result = new LanceIndexBuilder.IndexBuildResult( - true, - IndexType.IVF_PQ, - "embedding", - datasetPath, - 1000, - null - ); - - assertThat(result.isSuccess()).isTrue(); - assertThat(result.getIndexType()).isEqualTo(IndexType.IVF_PQ); - assertThat(result.getColumnName()).isEqualTo("embedding"); - assertThat(result.getDatasetPath()).isEqualTo(datasetPath); - assertThat(result.getDurationMillis()).isEqualTo(1000); - 
assertThat(result.getErrorMessage()).isNull(); - } - - @Test - @DisplayName("Test index build failure result") - void testIndexBuildFailureResult() { - LanceIndexBuilder.IndexBuildResult result = new LanceIndexBuilder.IndexBuildResult( - false, - IndexType.IVF_PQ, - "embedding", - datasetPath, - 500, - "Column does not exist" - ); - - assertThat(result.isSuccess()).isFalse(); - assertThat(result.getErrorMessage()).isEqualTo("Column does not exist"); - } - - @Test - @DisplayName("Test replace index option") - void testReplaceIndexOption() { - LanceIndexBuilder builder = LanceIndexBuilder.builder() - .datasetPath(datasetPath) - .columnName("embedding") - .indexType(IndexType.IVF_PQ) - .replace(true) - .build(); - - assertThat(builder).isNotNull(); - } + @TempDir Path tempDir; + + private String datasetPath; + + @BeforeEach + void setUp() { + datasetPath = tempDir.resolve("test_index_dataset").toString(); + } + + @Test + @DisplayName("Test IVF_PQ index configuration build") + void testIvfPqIndexConfiguration() { + LanceIndexBuilder builder = + LanceIndexBuilder.builder() + .datasetPath(datasetPath) + .columnName("embedding") + .indexType(IndexType.IVF_PQ) + .numPartitions(128) + .numSubVectors(16) + .numBits(8) + .metricType(MetricType.L2) + .build(); + + // Verify configuration - by successful build + assertThat(builder).isNotNull(); + } + + @Test + @DisplayName("Test IVF_HNSW index configuration build") + void testIvfHnswIndexConfiguration() { + LanceIndexBuilder builder = + LanceIndexBuilder.builder() + .datasetPath(datasetPath) + .columnName("embedding") + .indexType(IndexType.IVF_HNSW) + .numPartitions(64) + .maxLevel(5) + .m(24) + .efConstruction(200) + .metricType(MetricType.COSINE) + .build(); + + assertThat(builder).isNotNull(); + } + + @Test + @DisplayName("Test IVF_FLAT index configuration build") + void testIvfFlatIndexConfiguration() { + LanceIndexBuilder builder = + LanceIndexBuilder.builder() + .datasetPath(datasetPath) + .columnName("embedding") + .indexType(IndexType.IVF_FLAT) + .numPartitions(256) + .metricType(MetricType.DOT) + .build(); + + assertThat(builder).isNotNull(); + } + + @Test + @DisplayName("Test index type enum") + void testIndexTypeEnum() { + assertThat(IndexType.fromValue("IVF_PQ")).isEqualTo(IndexType.IVF_PQ); + assertThat(IndexType.fromValue("ivf_pq")).isEqualTo(IndexType.IVF_PQ); + assertThat(IndexType.fromValue("IVF_HNSW")).isEqualTo(IndexType.IVF_HNSW); + assertThat(IndexType.fromValue("IVF_FLAT")).isEqualTo(IndexType.IVF_FLAT); + + assertThat(IndexType.IVF_PQ.getValue()).isEqualTo("IVF_PQ"); + assertThat(IndexType.IVF_HNSW.getValue()).isEqualTo("IVF_HNSW"); + assertThat(IndexType.IVF_FLAT.getValue()).isEqualTo("IVF_FLAT"); + } + + @Test + @DisplayName("Test invalid index type") + void testInvalidIndexType() { + assertThatThrownBy(() -> IndexType.fromValue("INVALID")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Unsupported index type"); + } + + @Test + @DisplayName("Test metric type enum") + void testMetricTypeEnum() { + assertThat(MetricType.fromValue("L2")).isEqualTo(MetricType.L2); + assertThat(MetricType.fromValue("l2")).isEqualTo(MetricType.L2); + assertThat(MetricType.fromValue("Cosine")).isEqualTo(MetricType.COSINE); + assertThat(MetricType.fromValue("cosine")).isEqualTo(MetricType.COSINE); + assertThat(MetricType.fromValue("Dot")).isEqualTo(MetricType.DOT); + assertThat(MetricType.fromValue("dot")).isEqualTo(MetricType.DOT); + + assertThat(MetricType.L2.getValue()).isEqualTo("L2"); + 
assertThat(MetricType.COSINE.getValue()).isEqualTo("Cosine"); + assertThat(MetricType.DOT.getValue()).isEqualTo("Dot"); + } + + @Test + @DisplayName("Test invalid metric type") + void testInvalidMetricType() { + assertThatThrownBy(() -> MetricType.fromValue("INVALID")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Unsupported metric type"); + } + + @Test + @DisplayName("Test exception when missing dataset path") + void testMissingDatasetPath() { + assertThatThrownBy( + () -> + LanceIndexBuilder.builder() + .columnName("embedding") + .indexType(IndexType.IVF_PQ) + .build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Dataset path cannot be empty"); + } + + @Test + @DisplayName("Test exception when missing column name") + void testMissingColumnName() { + assertThatThrownBy( + () -> + LanceIndexBuilder.builder() + .datasetPath(datasetPath) + .indexType(IndexType.IVF_PQ) + .build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Column name cannot be empty"); + } + + @Test + @DisplayName("Test invalid number of partitions") + void testInvalidNumPartitions() { + assertThatThrownBy( + () -> + LanceIndexBuilder.builder() + .datasetPath(datasetPath) + .columnName("embedding") + .numPartitions(0) + .build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Number of partitions must be greater than 0"); + } + + @Test + @DisplayName("Test invalid number of sub-vectors") + void testInvalidNumSubVectors() { + assertThatThrownBy( + () -> + LanceIndexBuilder.builder() + .datasetPath(datasetPath) + .columnName("embedding") + .numSubVectors(-1) + .build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Number of sub-vectors must be greater than 0"); + } + + @Test + @DisplayName("Test invalid number of quantization bits") + void testInvalidNumBits() { + assertThatThrownBy( + () -> + LanceIndexBuilder.builder() + .datasetPath(datasetPath) + .columnName("embedding") + .numBits(0) + .build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Quantization bits must be between 1 and 16"); + + assertThatThrownBy( + () -> + LanceIndexBuilder.builder() + .datasetPath(datasetPath) + .columnName("embedding") + .numBits(17) + .build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Quantization bits must be between 1 and 16"); + } + + @Test + @DisplayName("Test default index configuration values") + void testDefaultIndexConfiguration() { + LanceOptions options = + LanceOptions.builder().path(datasetPath).indexColumn("embedding").build(); + + // Verify default values + assertThat(options.getIndexType()).isEqualTo(IndexType.IVF_PQ); + assertThat(options.getIndexNumPartitions()).isEqualTo(256); + assertThat(options.getIndexNumBits()).isEqualTo(8); + assertThat(options.getIndexMaxLevel()).isEqualTo(7); + assertThat(options.getIndexM()).isEqualTo(16); + assertThat(options.getIndexEfConstruction()).isEqualTo(100); + } + + @Test + @DisplayName("Test creating index builder from LanceOptions") + void testFromOptions() { + LanceOptions options = + LanceOptions.builder() + .path(datasetPath) + .indexColumn("embedding") + .indexType(IndexType.IVF_HNSW) + .indexNumPartitions(64) + .vectorMetric(MetricType.COSINE) + .build(); + + LanceIndexBuilder builder = LanceIndexBuilder.fromOptions(options); + + assertThat(builder).isNotNull(); + } + + @Test + @DisplayName("Test index build result") + void testIndexBuildResult() { + 
LanceIndexBuilder.IndexBuildResult result = + new LanceIndexBuilder.IndexBuildResult( + true, IndexType.IVF_PQ, "embedding", datasetPath, 1000, null); + + assertThat(result.isSuccess()).isTrue(); + assertThat(result.getIndexType()).isEqualTo(IndexType.IVF_PQ); + assertThat(result.getColumnName()).isEqualTo("embedding"); + assertThat(result.getDatasetPath()).isEqualTo(datasetPath); + assertThat(result.getDurationMillis()).isEqualTo(1000); + assertThat(result.getErrorMessage()).isNull(); + } + + @Test + @DisplayName("Test index build failure result") + void testIndexBuildFailureResult() { + LanceIndexBuilder.IndexBuildResult result = + new LanceIndexBuilder.IndexBuildResult( + false, IndexType.IVF_PQ, "embedding", datasetPath, 500, "Column does not exist"); + + assertThat(result.isSuccess()).isFalse(); + assertThat(result.getErrorMessage()).isEqualTo("Column does not exist"); + } + + @Test + @DisplayName("Test replace index option") + void testReplaceIndexOption() { + LanceIndexBuilder builder = + LanceIndexBuilder.builder() + .datasetPath(datasetPath) + .columnName("embedding") + .indexType(IndexType.IVF_PQ) + .replace(true) + .build(); + + assertThat(builder).isNotNull(); + } } diff --git a/src/test/java/org/apache/flink/connector/lance/LanceSinkTest.java b/src/test/java/org/apache/flink/connector/lance/LanceSinkTest.java index 67252fb..e65b845 100644 --- a/src/test/java/org/apache/flink/connector/lance/LanceSinkTest.java +++ b/src/test/java/org/apache/flink/connector/lance/LanceSinkTest.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.flink.connector.lance; import org.apache.flink.connector.lance.config.LanceOptions; @@ -37,177 +32,159 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -/** - * LanceSink unit tests. - */ +/** LanceSink unit tests. 
*/
class LanceSinkTest {
 
- @TempDir
- Path tempDir;
-
- private String datasetPath;
- private RowType rowType;
-
- @BeforeEach
- void setUp() {
- datasetPath = tempDir.resolve("test_sink_dataset").toString();
-
- // Create test RowType
- List<RowType.RowField> fields = new ArrayList<>();
- fields.add(new RowType.RowField("id", new BigIntType()));
- fields.add(new RowType.RowField("content", new VarCharType()));
- fields.add(new RowType.RowField("embedding", new ArrayType(new FloatType())));
- rowType = new RowType(fields);
- }
-
- @Test
- @DisplayName("Test LanceSink configuration build")
- void testSinkConfiguration() {
- LanceOptions options = LanceOptions.builder()
- .path(datasetPath)
- .writeBatchSize(512)
- .writeMode(LanceOptions.WriteMode.APPEND)
- .writeMaxRowsPerFile(500000)
- .build();
-
- LanceSink sink = new LanceSink(options, rowType);
-
- assertThat(sink.getOptions().getPath()).isEqualTo(datasetPath);
- assertThat(sink.getOptions().getWriteBatchSize()).isEqualTo(512);
- assertThat(sink.getOptions().getWriteMode()).isEqualTo(LanceOptions.WriteMode.APPEND);
- assertThat(sink.getOptions().getWriteMaxRowsPerFile()).isEqualTo(500000);
- assertThat(sink.getRowType()).isEqualTo(rowType);
- }
-
- @Test
- @DisplayName("Test LanceSink Builder pattern")
- void testSinkBuilder() {
- LanceSink sink = LanceSink.builder()
- .path(datasetPath)
- .batchSize(256)
- .writeMode(LanceOptions.WriteMode.OVERWRITE)
- .maxRowsPerFile(100000)
- .rowType(rowType)
- .build();
-
- assertThat(sink.getOptions().getPath()).isEqualTo(datasetPath);
- assertThat(sink.getOptions().getWriteBatchSize()).isEqualTo(256);
- assertThat(sink.getOptions().getWriteMode()).isEqualTo(LanceOptions.WriteMode.OVERWRITE);
- assertThat(sink.getOptions().getWriteMaxRowsPerFile()).isEqualTo(100000);
- }
-
- @Test
- @DisplayName("Test LanceSink Builder throws exception when missing path")
- void testSinkBuilderMissingPath() {
- assertThatThrownBy(() -> LanceSink.builder()
- .rowType(rowType)
- .build())
- .isInstanceOf(IllegalArgumentException.class)
- .hasMessageContaining("Dataset path cannot be empty");
- }
-
- @Test
- @DisplayName("Test LanceSink Builder throws exception when missing RowType")
- void testSinkBuilderMissingRowType() {
- assertThatThrownBy(() -> LanceSink.builder()
- .path(datasetPath)
- .build())
- .isInstanceOf(IllegalArgumentException.class)
- .hasMessageContaining("RowType");
- }
-
- @Test
- @DisplayName("Test default Sink configuration values")
- void testDefaultSinkConfiguration() {
- LanceOptions options = LanceOptions.builder()
- .path(datasetPath)
- .build();
-
- // Verify default values
- assertThat(options.getWriteBatchSize()).isEqualTo(1024);
- assertThat(options.getWriteMode()).isEqualTo(LanceOptions.WriteMode.APPEND);
- assertThat(options.getWriteMaxRowsPerFile()).isEqualTo(1000000);
- }
-
- @Test
- @DisplayName("Test write mode enum")
- void testWriteMode() {
- assertThat(LanceOptions.WriteMode.fromValue("append"))
- .isEqualTo(LanceOptions.WriteMode.APPEND);
- assertThat(LanceOptions.WriteMode.fromValue("APPEND"))
- .isEqualTo(LanceOptions.WriteMode.APPEND);
- assertThat(LanceOptions.WriteMode.fromValue("overwrite"))
- .isEqualTo(LanceOptions.WriteMode.OVERWRITE);
- assertThat(LanceOptions.WriteMode.fromValue("OVERWRITE"))
- .isEqualTo(LanceOptions.WriteMode.OVERWRITE);
- }
-
- @Test
- @DisplayName("Test invalid write mode")
- void testInvalidWriteMode() {
- assertThatThrownBy(() -> LanceOptions.WriteMode.fromValue("invalid"))
- .isInstanceOf(IllegalArgumentException.class)
- .hasMessageContaining("Unsupported
write mode"); - } - - @Test - @DisplayName("Test configuration validation - invalid write batch size") - void testInvalidWriteBatchSize() { - assertThatThrownBy(() -> LanceOptions.builder() - .path(datasetPath) - .writeBatchSize(0) - .build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("batch-size"); - } - - @Test - @DisplayName("Test configuration validation - invalid max rows per file") - void testInvalidMaxRowsPerFile() { - assertThatThrownBy(() -> LanceOptions.builder() - .path(datasetPath) - .writeMaxRowsPerFile(-1) - .build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("max-rows-per-file"); - } - - @Test - @DisplayName("Test vector type write configuration") - void testVectorWriteConfiguration() { - List fields = new ArrayList<>(); - fields.add(new RowType.RowField("id", new BigIntType())); - fields.add(new RowType.RowField("embedding", new ArrayType(new FloatType()))); - RowType vectorRowType = new RowType(fields); - - LanceOptions options = LanceOptions.builder() - .path(datasetPath) - .writeBatchSize(100) - .build(); - - LanceSink sink = new LanceSink(options, vectorRowType); - - assertThat(sink.getRowType().getFieldCount()).isEqualTo(2); - assertThat(sink.getRowType().getTypeAt(1)).isInstanceOf(ArrayType.class); - } - - @Test - @DisplayName("Test APPEND and OVERWRITE mode configuration") - void testWriteModeConfiguration() { - // APPEND mode - LanceOptions appendOptions = LanceOptions.builder() - .path(datasetPath) - .writeMode(LanceOptions.WriteMode.APPEND) - .build(); - assertThat(appendOptions.getWriteMode()).isEqualTo(LanceOptions.WriteMode.APPEND); - assertThat(appendOptions.getWriteMode().getValue()).isEqualTo("append"); - - // OVERWRITE mode - LanceOptions overwriteOptions = LanceOptions.builder() - .path(datasetPath) - .writeMode(LanceOptions.WriteMode.OVERWRITE) - .build(); - assertThat(overwriteOptions.getWriteMode()).isEqualTo(LanceOptions.WriteMode.OVERWRITE); - assertThat(overwriteOptions.getWriteMode().getValue()).isEqualTo("overwrite"); - } + @TempDir Path tempDir; + + private String datasetPath; + private RowType rowType; + + @BeforeEach + void setUp() { + datasetPath = tempDir.resolve("test_sink_dataset").toString(); + + // Create test RowType + List fields = new ArrayList<>(); + fields.add(new RowType.RowField("id", new BigIntType())); + fields.add(new RowType.RowField("content", new VarCharType())); + fields.add(new RowType.RowField("embedding", new ArrayType(new FloatType()))); + rowType = new RowType(fields); + } + + @Test + @DisplayName("Test LanceSink configuration build") + void testSinkConfiguration() { + LanceOptions options = + LanceOptions.builder() + .path(datasetPath) + .writeBatchSize(512) + .writeMode(LanceOptions.WriteMode.APPEND) + .writeMaxRowsPerFile(500000) + .build(); + + LanceSink sink = new LanceSink(options, rowType); + + assertThat(sink.getOptions().getPath()).isEqualTo(datasetPath); + assertThat(sink.getOptions().getWriteBatchSize()).isEqualTo(512); + assertThat(sink.getOptions().getWriteMode()).isEqualTo(LanceOptions.WriteMode.APPEND); + assertThat(sink.getOptions().getWriteMaxRowsPerFile()).isEqualTo(500000); + assertThat(sink.getRowType()).isEqualTo(rowType); + } + + @Test + @DisplayName("Test LanceSink Builder pattern") + void testSinkBuilder() { + LanceSink sink = + LanceSink.builder() + .path(datasetPath) + .batchSize(256) + .writeMode(LanceOptions.WriteMode.OVERWRITE) + .maxRowsPerFile(100000) + .rowType(rowType) + .build(); + + 
assertThat(sink.getOptions().getPath()).isEqualTo(datasetPath);
+ assertThat(sink.getOptions().getWriteBatchSize()).isEqualTo(256);
+ assertThat(sink.getOptions().getWriteMode()).isEqualTo(LanceOptions.WriteMode.OVERWRITE);
+ assertThat(sink.getOptions().getWriteMaxRowsPerFile()).isEqualTo(100000);
+ }
+
+ @Test
+ @DisplayName("Test LanceSink Builder throws exception when missing path")
+ void testSinkBuilderMissingPath() {
+ assertThatThrownBy(() -> LanceSink.builder().rowType(rowType).build())
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("Dataset path cannot be empty");
+ }
+
+ @Test
+ @DisplayName("Test LanceSink Builder throws exception when missing RowType")
+ void testSinkBuilderMissingRowType() {
+ assertThatThrownBy(() -> LanceSink.builder().path(datasetPath).build())
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("RowType");
+ }
+
+ @Test
+ @DisplayName("Test default Sink configuration values")
+ void testDefaultSinkConfiguration() {
+ LanceOptions options = LanceOptions.builder().path(datasetPath).build();
+
+ // Verify default values
+ assertThat(options.getWriteBatchSize()).isEqualTo(1024);
+ assertThat(options.getWriteMode()).isEqualTo(LanceOptions.WriteMode.APPEND);
+ assertThat(options.getWriteMaxRowsPerFile()).isEqualTo(1000000);
+ }
+
+ @Test
+ @DisplayName("Test write mode enum")
+ void testWriteMode() {
+ assertThat(LanceOptions.WriteMode.fromValue("append")).isEqualTo(LanceOptions.WriteMode.APPEND);
+ assertThat(LanceOptions.WriteMode.fromValue("APPEND")).isEqualTo(LanceOptions.WriteMode.APPEND);
+ assertThat(LanceOptions.WriteMode.fromValue("overwrite"))
+ .isEqualTo(LanceOptions.WriteMode.OVERWRITE);
+ assertThat(LanceOptions.WriteMode.fromValue("OVERWRITE"))
+ .isEqualTo(LanceOptions.WriteMode.OVERWRITE);
+ }
+
+ @Test
+ @DisplayName("Test invalid write mode")
+ void testInvalidWriteMode() {
+ assertThatThrownBy(() -> LanceOptions.WriteMode.fromValue("invalid"))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("Unsupported write mode");
+ }
+
+ @Test
+ @DisplayName("Test configuration validation - invalid write batch size")
+ void testInvalidWriteBatchSize() {
+ assertThatThrownBy(() -> LanceOptions.builder().path(datasetPath).writeBatchSize(0).build())
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("batch-size");
+ }
+
+ @Test
+ @DisplayName("Test configuration validation - invalid max rows per file")
+ void testInvalidMaxRowsPerFile() {
+ assertThatThrownBy(
+ () -> LanceOptions.builder().path(datasetPath).writeMaxRowsPerFile(-1).build())
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("max-rows-per-file");
+ }
+
+ @Test
+ @DisplayName("Test vector type write configuration")
+ void testVectorWriteConfiguration() {
+ List<RowType.RowField> fields = new ArrayList<>();
+ fields.add(new RowType.RowField("id", new BigIntType()));
+ fields.add(new RowType.RowField("embedding", new ArrayType(new FloatType())));
+ RowType vectorRowType = new RowType(fields);
+
+ LanceOptions options = LanceOptions.builder().path(datasetPath).writeBatchSize(100).build();
+
+ LanceSink sink = new LanceSink(options, vectorRowType);
+
+ assertThat(sink.getRowType().getFieldCount()).isEqualTo(2);
+ assertThat(sink.getRowType().getTypeAt(1)).isInstanceOf(ArrayType.class);
+ }
+
+ @Test
+ @DisplayName("Test APPEND and OVERWRITE mode configuration")
+ void testWriteModeConfiguration() {
+ // APPEND mode
+ LanceOptions appendOptions =
+
LanceOptions.builder().path(datasetPath).writeMode(LanceOptions.WriteMode.APPEND).build(); + assertThat(appendOptions.getWriteMode()).isEqualTo(LanceOptions.WriteMode.APPEND); + assertThat(appendOptions.getWriteMode().getValue()).isEqualTo("append"); + + // OVERWRITE mode + LanceOptions overwriteOptions = + LanceOptions.builder() + .path(datasetPath) + .writeMode(LanceOptions.WriteMode.OVERWRITE) + .build(); + assertThat(overwriteOptions.getWriteMode()).isEqualTo(LanceOptions.WriteMode.OVERWRITE); + assertThat(overwriteOptions.getWriteMode().getValue()).isEqualTo("overwrite"); + } } diff --git a/src/test/java/org/apache/flink/connector/lance/LanceSourceTest.java b/src/test/java/org/apache/flink/connector/lance/LanceSourceTest.java index 7b77a04..c25f342 100644 --- a/src/test/java/org/apache/flink/connector/lance/LanceSourceTest.java +++ b/src/test/java/org/apache/flink/connector/lance/LanceSourceTest.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.flink.connector.lance; import org.apache.flink.connector.lance.config.LanceOptions; @@ -38,151 +33,138 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -/** - * LanceSource unit tests. - */ +/** LanceSource unit tests. 
*/
class LanceSourceTest {
 
- @TempDir
- Path tempDir;
-
- private String datasetPath;
- private RowType rowType;
-
- @BeforeEach
- void setUp() {
- datasetPath = tempDir.resolve("test_dataset").toString();
-
- // Create test RowType
- List<RowType.RowField> fields = new ArrayList<>();
- fields.add(new RowType.RowField("id", new BigIntType()));
- fields.add(new RowType.RowField("content", new VarCharType()));
- fields.add(new RowType.RowField("embedding", new ArrayType(new FloatType())));
- rowType = new RowType(fields);
- }
-
- @Test
- @DisplayName("Test LanceSource configuration build")
- void testSourceConfiguration() {
- LanceOptions options = LanceOptions.builder()
- .path(datasetPath)
- .readBatchSize(512)
- .readColumns(Arrays.asList("id", "content"))
- .readFilter("id > 10")
- .build();
-
- LanceSource source = new LanceSource(options, rowType);
-
- assertThat(source.getOptions().getPath()).isEqualTo(datasetPath);
- assertThat(source.getOptions().getReadBatchSize()).isEqualTo(512);
- assertThat(source.getOptions().getReadColumns()).containsExactly("id", "content");
- assertThat(source.getOptions().getReadFilter()).isEqualTo("id > 10");
- assertThat(source.getRowType()).isEqualTo(rowType);
- }
-
- @Test
- @DisplayName("Test LanceSource Builder pattern")
- void testSourceBuilder() {
- LanceSource source = LanceSource.builder()
- .path(datasetPath)
- .batchSize(256)
- .columns(Arrays.asList("id"))
- .filter("id < 100")
- .rowType(rowType)
- .build();
-
- assertThat(source.getOptions().getPath()).isEqualTo(datasetPath);
- assertThat(source.getOptions().getReadBatchSize()).isEqualTo(256);
- assertThat(source.getSelectedColumns()).containsExactly("id");
- }
-
- @Test
- @DisplayName("Test LanceSource Builder throws exception when missing path")
- void testSourceBuilderMissingPath() {
- assertThatThrownBy(() -> LanceSource.builder()
- .rowType(rowType)
- .build())
- .isInstanceOf(IllegalArgumentException.class)
- .hasMessageContaining("Dataset path cannot be empty");
- }
-
- @Test
- @DisplayName("Test LanceSplit creation")
- void testLanceSplit() {
- LanceSplit split = new LanceSplit(0, 1, datasetPath, 1000);
-
- assertThat(split.getSplitNumber()).isEqualTo(0);
- assertThat(split.getFragmentId()).isEqualTo(1);
- assertThat(split.getDatasetPath()).isEqualTo(datasetPath);
- assertThat(split.getRowCount()).isEqualTo(1000);
- }
-
- @Test
- @DisplayName("Test LanceSplit equality")
- void testLanceSplitEquality() {
- LanceSplit split1 = new LanceSplit(0, 1, datasetPath, 1000);
- LanceSplit split2 = new LanceSplit(0, 1, datasetPath, 1000);
- LanceSplit split3 = new LanceSplit(1, 2, datasetPath, 2000);
-
- assertThat(split1).isEqualTo(split2);
- assertThat(split1.hashCode()).isEqualTo(split2.hashCode());
- assertThat(split1).isNotEqualTo(split3);
- }
-
- @Test
- @DisplayName("Test LanceInputFormat configuration")
- void testInputFormatConfiguration() {
- LanceOptions options = LanceOptions.builder()
- .path(datasetPath)
- .readBatchSize(128)
- .build();
-
- LanceInputFormat inputFormat = new LanceInputFormat(options, rowType);
-
- assertThat(inputFormat.getOptions().getPath()).isEqualTo(datasetPath);
- assertThat(inputFormat.getOptions().getReadBatchSize()).isEqualTo(128);
- assertThat(inputFormat.getRowType()).isEqualTo(rowType);
- }
-
- @Test
- @DisplayName("Test default configuration values")
- void testDefaultConfiguration() {
- LanceOptions options = LanceOptions.builder()
- .path(datasetPath)
- .build();
-
- // Verify default values
- assertThat(options.getReadBatchSize()).isEqualTo(1024);
-
assertThat(options.getReadColumns()).isEmpty(); - assertThat(options.getReadFilter()).isNull(); - } - - @Test - @DisplayName("Test configuration validation - invalid batch size") - void testInvalidBatchSize() { - assertThatThrownBy(() -> LanceOptions.builder() - .path(datasetPath) - .readBatchSize(0) - .build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("batch-size"); - } - - @Test - @DisplayName("Test vector type RowType") - void testVectorRowType() { - List fields = new ArrayList<>(); - fields.add(new RowType.RowField("id", new BigIntType())); - fields.add(new RowType.RowField("embedding", new ArrayType(new FloatType()))); - RowType vectorRowType = new RowType(fields); - - LanceOptions options = LanceOptions.builder() - .path(datasetPath) - .build(); - - LanceSource source = new LanceSource(options, vectorRowType); - - assertThat(source.getRowType().getFieldCount()).isEqualTo(2); - assertThat(source.getRowType().getTypeAt(1)).isInstanceOf(ArrayType.class); - } + @TempDir Path tempDir; + + private String datasetPath; + private RowType rowType; + + @BeforeEach + void setUp() { + datasetPath = tempDir.resolve("test_dataset").toString(); + + // Create test RowType + List fields = new ArrayList<>(); + fields.add(new RowType.RowField("id", new BigIntType())); + fields.add(new RowType.RowField("content", new VarCharType())); + fields.add(new RowType.RowField("embedding", new ArrayType(new FloatType()))); + rowType = new RowType(fields); + } + + @Test + @DisplayName("Test LanceSource configuration build") + void testSourceConfiguration() { + LanceOptions options = + LanceOptions.builder() + .path(datasetPath) + .readBatchSize(512) + .readColumns(Arrays.asList("id", "content")) + .readFilter("id > 10") + .build(); + + LanceSource source = new LanceSource(options, rowType); + + assertThat(source.getOptions().getPath()).isEqualTo(datasetPath); + assertThat(source.getOptions().getReadBatchSize()).isEqualTo(512); + assertThat(source.getOptions().getReadColumns()).containsExactly("id", "content"); + assertThat(source.getOptions().getReadFilter()).isEqualTo("id > 10"); + assertThat(source.getRowType()).isEqualTo(rowType); + } + + @Test + @DisplayName("Test LanceSource Builder pattern") + void testSourceBuilder() { + LanceSource source = + LanceSource.builder() + .path(datasetPath) + .batchSize(256) + .columns(Arrays.asList("id")) + .filter("id < 100") + .rowType(rowType) + .build(); + + assertThat(source.getOptions().getPath()).isEqualTo(datasetPath); + assertThat(source.getOptions().getReadBatchSize()).isEqualTo(256); + assertThat(source.getSelectedColumns()).containsExactly("id"); + } + + @Test + @DisplayName("Test LanceSource Builder throws exception when missing path") + void testSourceBuilderMissingPath() { + assertThatThrownBy(() -> LanceSource.builder().rowType(rowType).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Dataset path cannot be empty"); + } + + @Test + @DisplayName("Test LanceSplit creation") + void testLanceSplit() { + LanceSplit split = new LanceSplit(0, 1, datasetPath, 1000); + + assertThat(split.getSplitNumber()).isEqualTo(0); + assertThat(split.getFragmentId()).isEqualTo(1); + assertThat(split.getDatasetPath()).isEqualTo(datasetPath); + assertThat(split.getRowCount()).isEqualTo(1000); + } + + @Test + @DisplayName("Test LanceSplit equality") + void testLanceSplitEquality() { + LanceSplit split1 = new LanceSplit(0, 1, datasetPath, 1000); + LanceSplit split2 = new LanceSplit(0, 1, datasetPath, 1000); + LanceSplit 
split3 = new LanceSplit(1, 2, datasetPath, 2000); + + assertThat(split1).isEqualTo(split2); + assertThat(split1.hashCode()).isEqualTo(split2.hashCode()); + assertThat(split1).isNotEqualTo(split3); + } + + @Test + @DisplayName("Test LanceInputFormat configuration") + void testInputFormatConfiguration() { + LanceOptions options = LanceOptions.builder().path(datasetPath).readBatchSize(128).build(); + + LanceInputFormat inputFormat = new LanceInputFormat(options, rowType); + + assertThat(inputFormat.getOptions().getPath()).isEqualTo(datasetPath); + assertThat(inputFormat.getOptions().getReadBatchSize()).isEqualTo(128); + assertThat(inputFormat.getRowType()).isEqualTo(rowType); + } + + @Test + @DisplayName("Test default configuration values") + void testDefaultConfiguration() { + LanceOptions options = LanceOptions.builder().path(datasetPath).build(); + + // Verify default values + assertThat(options.getReadBatchSize()).isEqualTo(1024); + assertThat(options.getReadColumns()).isEmpty(); + assertThat(options.getReadFilter()).isNull(); + } + + @Test + @DisplayName("Test configuration validation - invalid batch size") + void testInvalidBatchSize() { + assertThatThrownBy(() -> LanceOptions.builder().path(datasetPath).readBatchSize(0).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("batch-size"); + } + + @Test + @DisplayName("Test vector type RowType") + void testVectorRowType() { + List fields = new ArrayList<>(); + fields.add(new RowType.RowField("id", new BigIntType())); + fields.add(new RowType.RowField("embedding", new ArrayType(new FloatType()))); + RowType vectorRowType = new RowType(fields); + + LanceOptions options = LanceOptions.builder().path(datasetPath).build(); + + LanceSource source = new LanceSource(options, vectorRowType); + + assertThat(source.getRowType().getFieldCount()).isEqualTo(2); + assertThat(source.getRowType().getTypeAt(1)).isInstanceOf(ArrayType.class); + } } diff --git a/src/test/java/org/apache/flink/connector/lance/LanceTypeConverterTest.java b/src/test/java/org/apache/flink/connector/lance/LanceTypeConverterTest.java index aa35b6d..fe36307 100644 --- a/src/test/java/org/apache/flink/connector/lance/LanceTypeConverterTest.java +++ b/src/test/java/org/apache/flink/connector/lance/LanceTypeConverterTest.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.flink.connector.lance; import org.apache.flink.connector.lance.converter.LanceTypeConverter; @@ -51,266 +46,279 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -/** - * LanceTypeConverter unit tests. - */ +/** LanceTypeConverter unit tests. 
*/ class LanceTypeConverterTest { - @Test - @DisplayName("Test Arrow Int type to Flink type mapping") - void testArrowIntToFlinkType() { - // Int8 -> TINYINT - Field int8Field = new Field("int8", FieldType.nullable(new ArrowType.Int(8, true)), null); - LogicalType int8Type = LanceTypeConverter.arrowTypeToFlinkType(int8Field); - assertThat(int8Type).isInstanceOf(TinyIntType.class); - - // Int16 -> SMALLINT - Field int16Field = new Field("int16", FieldType.nullable(new ArrowType.Int(16, true)), null); - LogicalType int16Type = LanceTypeConverter.arrowTypeToFlinkType(int16Field); - assertThat(int16Type).isInstanceOf(SmallIntType.class); - - // Int32 -> INT - Field int32Field = new Field("int32", FieldType.nullable(new ArrowType.Int(32, true)), null); - LogicalType int32Type = LanceTypeConverter.arrowTypeToFlinkType(int32Field); - assertThat(int32Type).isInstanceOf(IntType.class); - - // Int64 -> BIGINT - Field int64Field = new Field("int64", FieldType.nullable(new ArrowType.Int(64, true)), null); - LogicalType int64Type = LanceTypeConverter.arrowTypeToFlinkType(int64Field); - assertThat(int64Type).isInstanceOf(BigIntType.class); - } - - @Test - @DisplayName("Test Arrow floating point type to Flink type mapping") - void testArrowFloatToFlinkType() { - // Float32 -> FLOAT - Field float32Field = new Field("float32", - FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), null); - LogicalType float32Type = LanceTypeConverter.arrowTypeToFlinkType(float32Field); - assertThat(float32Type).isInstanceOf(FloatType.class); - - // Float64 -> DOUBLE - Field float64Field = new Field("float64", - FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), null); - LogicalType float64Type = LanceTypeConverter.arrowTypeToFlinkType(float64Field); - assertThat(float64Type).isInstanceOf(DoubleType.class); - } - - @Test - @DisplayName("Test Arrow string type to Flink type mapping") - void testArrowStringToFlinkType() { - // String -> STRING - Field stringField = new Field("str", FieldType.nullable(ArrowType.Utf8.INSTANCE), null); - LogicalType stringType = LanceTypeConverter.arrowTypeToFlinkType(stringField); - assertThat(stringType).isInstanceOf(VarCharType.class); - - // LargeString -> STRING - Field largeStringField = new Field("large_str", FieldType.nullable(ArrowType.LargeUtf8.INSTANCE), null); - LogicalType largeStringType = LanceTypeConverter.arrowTypeToFlinkType(largeStringField); - assertThat(largeStringType).isInstanceOf(VarCharType.class); - } - - @Test - @DisplayName("Test Arrow Boolean type to Flink type mapping") - void testArrowBoolToFlinkType() { - Field boolField = new Field("bool", FieldType.nullable(ArrowType.Bool.INSTANCE), null); - LogicalType boolType = LanceTypeConverter.arrowTypeToFlinkType(boolField); - assertThat(boolType).isInstanceOf(BooleanType.class); - } - - @Test - @DisplayName("Test Arrow Binary type to Flink type mapping") - void testArrowBinaryToFlinkType() { - Field binaryField = new Field("binary", FieldType.nullable(ArrowType.Binary.INSTANCE), null); - LogicalType binaryType = LanceTypeConverter.arrowTypeToFlinkType(binaryField); - assertThat(binaryType).isInstanceOf(VarBinaryType.class); - } - - @Test - @DisplayName("Test Arrow Date type to Flink type mapping") - void testArrowDateToFlinkType() { - Field dateField = new Field("date", - FieldType.nullable(new ArrowType.Date(DateUnit.DAY)), null); - LogicalType dateType = LanceTypeConverter.arrowTypeToFlinkType(dateField); - assertThat(dateType).isInstanceOf(DateType.class); - } - 
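+  // Precision mapping assumed by the timestamp tests in this class (only the
+  // MILLISECOND and MICROSECOND cases are actually asserted below):
+  //   new ArrowType.Timestamp(TimeUnit.MILLISECOND, null) -> TIMESTAMP(3)
+  //   new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) -> TIMESTAMP(6)
+  // SECOND and NANOSECOND, if the converter supports them, would presumably map
+  // to TIMESTAMP(0) and TIMESTAMP(9); that is an assumption, not verified here.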
- @Test - @DisplayName("Test Arrow Timestamp type to Flink type mapping") - void testArrowTimestampToFlinkType() { - // Millisecond precision - Field tsMilliField = new Field("ts_milli", - FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)), null); - LogicalType tsMilliType = LanceTypeConverter.arrowTypeToFlinkType(tsMilliField); - assertThat(tsMilliType).isInstanceOf(TimestampType.class); - assertThat(((TimestampType) tsMilliType).getPrecision()).isEqualTo(3); - - // Microsecond precision - Field tsMicroField = new Field("ts_micro", - FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)), null); - LogicalType tsMicroType = LanceTypeConverter.arrowTypeToFlinkType(tsMicroField); - assertThat(tsMicroType).isInstanceOf(TimestampType.class); - assertThat(((TimestampType) tsMicroType).getPrecision()).isEqualTo(6); - } - - @Test - @DisplayName("Test Arrow FixedSizeList (vector) type to Flink type mapping") - void testArrowVectorToFlinkType() { - // FixedSizeList -> ARRAY - ArrowType elementType = new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE); - Field elementField = new Field("item", FieldType.notNullable(elementType), null); - List children = Arrays.asList(elementField); - - Field vectorField = new Field("embedding", - FieldType.nullable(new ArrowType.FixedSizeList(128)), children); - - LogicalType vectorType = LanceTypeConverter.arrowTypeToFlinkType(vectorField); - assertThat(vectorType).isInstanceOf(ArrayType.class); - assertThat(((ArrayType) vectorType).getElementType()).isInstanceOf(FloatType.class); - } - - @Test - @DisplayName("Test Flink type to Arrow type mapping") - void testFlinkTypeToArrowType() { - // TINYINT -> Int8 - Field tinyIntField = LanceTypeConverter.flinkTypeToArrowField("tinyint", new TinyIntType()); - assertThat(tinyIntField.getType()).isInstanceOf(ArrowType.Int.class); - assertThat(((ArrowType.Int) tinyIntField.getType()).getBitWidth()).isEqualTo(8); - - // INT -> Int32 - Field intField = LanceTypeConverter.flinkTypeToArrowField("int", new IntType()); - assertThat(intField.getType()).isInstanceOf(ArrowType.Int.class); - assertThat(((ArrowType.Int) intField.getType()).getBitWidth()).isEqualTo(32); - - // BIGINT -> Int64 - Field bigIntField = LanceTypeConverter.flinkTypeToArrowField("bigint", new BigIntType()); - assertThat(bigIntField.getType()).isInstanceOf(ArrowType.Int.class); - assertThat(((ArrowType.Int) bigIntField.getType()).getBitWidth()).isEqualTo(64); - - // FLOAT -> Float32 - Field floatField = LanceTypeConverter.flinkTypeToArrowField("float", new FloatType()); - assertThat(floatField.getType()).isInstanceOf(ArrowType.FloatingPoint.class); - assertThat(((ArrowType.FloatingPoint) floatField.getType()).getPrecision()) - .isEqualTo(FloatingPointPrecision.SINGLE); - - // DOUBLE -> Float64 - Field doubleField = LanceTypeConverter.flinkTypeToArrowField("double", new DoubleType()); - assertThat(doubleField.getType()).isInstanceOf(ArrowType.FloatingPoint.class); - assertThat(((ArrowType.FloatingPoint) doubleField.getType()).getPrecision()) - .isEqualTo(FloatingPointPrecision.DOUBLE); - - // STRING -> Utf8 - Field stringField = LanceTypeConverter.flinkTypeToArrowField("string", new VarCharType()); - assertThat(stringField.getType()).isInstanceOf(ArrowType.Utf8.class); - - // BOOLEAN -> Bool - Field boolField = LanceTypeConverter.flinkTypeToArrowField("bool", new BooleanType()); - assertThat(boolField.getType()).isInstanceOf(ArrowType.Bool.class); - } - - @Test - @DisplayName("Test Arrow Schema to Flink RowType 
conversion") - void testArrowSchemaToFlinkRowType() { - List fields = new ArrayList<>(); - fields.add(new Field("id", FieldType.notNullable(new ArrowType.Int(64, true)), null)); - fields.add(new Field("name", FieldType.nullable(ArrowType.Utf8.INSTANCE), null)); - fields.add(new Field("score", - FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), null)); - - Schema arrowSchema = new Schema(fields); - RowType rowType = LanceTypeConverter.toFlinkRowType(arrowSchema); - - assertThat(rowType.getFieldCount()).isEqualTo(3); - assertThat(rowType.getFieldNames()).containsExactly("id", "name", "score"); - assertThat(rowType.getTypeAt(0)).isInstanceOf(BigIntType.class); - assertThat(rowType.getTypeAt(1)).isInstanceOf(VarCharType.class); - assertThat(rowType.getTypeAt(2)).isInstanceOf(DoubleType.class); - } - - @Test - @DisplayName("Test Flink RowType to Arrow Schema conversion") - void testFlinkRowTypeToArrowSchema() { - List fields = new ArrayList<>(); - fields.add(new RowType.RowField("id", new BigIntType(false))); - fields.add(new RowType.RowField("content", new VarCharType())); - fields.add(new RowType.RowField("embedding", new ArrayType(new FloatType()))); - - RowType rowType = new RowType(fields); - Schema arrowSchema = LanceTypeConverter.toArrowSchema(rowType); - - assertThat(arrowSchema.getFields()).hasSize(3); - assertThat(arrowSchema.getFields().get(0).getName()).isEqualTo("id"); - assertThat(arrowSchema.getFields().get(1).getName()).isEqualTo("content"); - assertThat(arrowSchema.getFields().get(2).getName()).isEqualTo("embedding"); - } - - @Test - @DisplayName("Test vector field creation") - void testCreateVectorField() { - // Float32 vector - Field float32Vector = LanceTypeConverter.createVectorField("embedding", 128, false); - assertThat(float32Vector.getName()).isEqualTo("embedding"); - assertThat(float32Vector.getType()).isInstanceOf(ArrowType.FixedSizeList.class); - assertThat(((ArrowType.FixedSizeList) float32Vector.getType()).getListSize()).isEqualTo(128); - assertThat(float32Vector.isNullable()).isFalse(); - - // Float64 vector - Field float64Vector = LanceTypeConverter.createFloat64VectorField("embedding64", 256, true); - assertThat(float64Vector.getName()).isEqualTo("embedding64"); - assertThat(((ArrowType.FixedSizeList) float64Vector.getType()).getListSize()).isEqualTo(256); - assertThat(float64Vector.isNullable()).isTrue(); - } - - @Test - @DisplayName("Test vector field detection") - void testIsVectorField() { - // Create vector field - Field vectorField = LanceTypeConverter.createVectorField("embedding", 128, false); - assertThat(LanceTypeConverter.isVectorField(vectorField)).isTrue(); - assertThat(LanceTypeConverter.getVectorDimension(vectorField)).isEqualTo(128); - - // Non-vector field - Field intField = new Field("id", FieldType.notNullable(new ArrowType.Int(64, true)), null); - assertThat(LanceTypeConverter.isVectorField(intField)).isFalse(); - assertThat(LanceTypeConverter.getVectorDimension(intField)).isEqualTo(-1); - } - - @Test - @DisplayName("Test unsupported type exception") - void testUnsupportedTypeException() { - // Unsupported Arrow type - Field unsupportedField = new Field("unsupported", - FieldType.nullable(new ArrowType.Duration(TimeUnit.SECOND)), null); - - assertThatThrownBy(() -> LanceTypeConverter.arrowTypeToFlinkType(unsupportedField)) - .isInstanceOf(LanceTypeConverter.UnsupportedTypeException.class) - .hasMessageContaining("Unsupported Arrow type"); - } - - @Test - @DisplayName("Test round-trip conversion consistency") - 
void testRoundTripConversion() { - // Create Flink RowType - List fields = new ArrayList<>(); - fields.add(new RowType.RowField("id", new BigIntType(false))); - fields.add(new RowType.RowField("name", new VarCharType())); - fields.add(new RowType.RowField("score", new DoubleType())); - fields.add(new RowType.RowField("active", new BooleanType())); - - RowType originalRowType = new RowType(fields); - - // Flink -> Arrow -> Flink - Schema arrowSchema = LanceTypeConverter.toArrowSchema(originalRowType); - RowType convertedRowType = LanceTypeConverter.toFlinkRowType(arrowSchema); - - // Verify field count and names - assertThat(convertedRowType.getFieldCount()).isEqualTo(originalRowType.getFieldCount()); - assertThat(convertedRowType.getFieldNames()).isEqualTo(originalRowType.getFieldNames()); - - // Verify types (type classes should match) - for (int i = 0; i < originalRowType.getFieldCount(); i++) { - assertThat(convertedRowType.getTypeAt(i).getClass()) - .isEqualTo(originalRowType.getTypeAt(i).getClass()); - } + @Test + @DisplayName("Test Arrow Int type to Flink type mapping") + void testArrowIntToFlinkType() { + // Int8 -> TINYINT + Field int8Field = new Field("int8", FieldType.nullable(new ArrowType.Int(8, true)), null); + LogicalType int8Type = LanceTypeConverter.arrowTypeToFlinkType(int8Field); + assertThat(int8Type).isInstanceOf(TinyIntType.class); + + // Int16 -> SMALLINT + Field int16Field = new Field("int16", FieldType.nullable(new ArrowType.Int(16, true)), null); + LogicalType int16Type = LanceTypeConverter.arrowTypeToFlinkType(int16Field); + assertThat(int16Type).isInstanceOf(SmallIntType.class); + + // Int32 -> INT + Field int32Field = new Field("int32", FieldType.nullable(new ArrowType.Int(32, true)), null); + LogicalType int32Type = LanceTypeConverter.arrowTypeToFlinkType(int32Field); + assertThat(int32Type).isInstanceOf(IntType.class); + + // Int64 -> BIGINT + Field int64Field = new Field("int64", FieldType.nullable(new ArrowType.Int(64, true)), null); + LogicalType int64Type = LanceTypeConverter.arrowTypeToFlinkType(int64Field); + assertThat(int64Type).isInstanceOf(BigIntType.class); + } + + @Test + @DisplayName("Test Arrow floating point type to Flink type mapping") + void testArrowFloatToFlinkType() { + // Float32 -> FLOAT + Field float32Field = + new Field( + "float32", + FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), + null); + LogicalType float32Type = LanceTypeConverter.arrowTypeToFlinkType(float32Field); + assertThat(float32Type).isInstanceOf(FloatType.class); + + // Float64 -> DOUBLE + Field float64Field = + new Field( + "float64", + FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), + null); + LogicalType float64Type = LanceTypeConverter.arrowTypeToFlinkType(float64Field); + assertThat(float64Type).isInstanceOf(DoubleType.class); + } + + @Test + @DisplayName("Test Arrow string type to Flink type mapping") + void testArrowStringToFlinkType() { + // String -> STRING + Field stringField = new Field("str", FieldType.nullable(ArrowType.Utf8.INSTANCE), null); + LogicalType stringType = LanceTypeConverter.arrowTypeToFlinkType(stringField); + assertThat(stringType).isInstanceOf(VarCharType.class); + + // LargeString -> STRING + Field largeStringField = + new Field("large_str", FieldType.nullable(ArrowType.LargeUtf8.INSTANCE), null); + LogicalType largeStringType = LanceTypeConverter.arrowTypeToFlinkType(largeStringField); + assertThat(largeStringType).isInstanceOf(VarCharType.class); + } + + @Test + 
@DisplayName("Test Arrow Boolean type to Flink type mapping") + void testArrowBoolToFlinkType() { + Field boolField = new Field("bool", FieldType.nullable(ArrowType.Bool.INSTANCE), null); + LogicalType boolType = LanceTypeConverter.arrowTypeToFlinkType(boolField); + assertThat(boolType).isInstanceOf(BooleanType.class); + } + + @Test + @DisplayName("Test Arrow Binary type to Flink type mapping") + void testArrowBinaryToFlinkType() { + Field binaryField = new Field("binary", FieldType.nullable(ArrowType.Binary.INSTANCE), null); + LogicalType binaryType = LanceTypeConverter.arrowTypeToFlinkType(binaryField); + assertThat(binaryType).isInstanceOf(VarBinaryType.class); + } + + @Test + @DisplayName("Test Arrow Date type to Flink type mapping") + void testArrowDateToFlinkType() { + Field dateField = new Field("date", FieldType.nullable(new ArrowType.Date(DateUnit.DAY)), null); + LogicalType dateType = LanceTypeConverter.arrowTypeToFlinkType(dateField); + assertThat(dateType).isInstanceOf(DateType.class); + } + + @Test + @DisplayName("Test Arrow Timestamp type to Flink type mapping") + void testArrowTimestampToFlinkType() { + // Millisecond precision + Field tsMilliField = + new Field( + "ts_milli", + FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)), + null); + LogicalType tsMilliType = LanceTypeConverter.arrowTypeToFlinkType(tsMilliField); + assertThat(tsMilliType).isInstanceOf(TimestampType.class); + assertThat(((TimestampType) tsMilliType).getPrecision()).isEqualTo(3); + + // Microsecond precision + Field tsMicroField = + new Field( + "ts_micro", + FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)), + null); + LogicalType tsMicroType = LanceTypeConverter.arrowTypeToFlinkType(tsMicroField); + assertThat(tsMicroType).isInstanceOf(TimestampType.class); + assertThat(((TimestampType) tsMicroType).getPrecision()).isEqualTo(6); + } + + @Test + @DisplayName("Test Arrow FixedSizeList (vector) type to Flink type mapping") + void testArrowVectorToFlinkType() { + // FixedSizeList -> ARRAY + ArrowType elementType = new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE); + Field elementField = new Field("item", FieldType.notNullable(elementType), null); + List children = Arrays.asList(elementField); + + Field vectorField = + new Field("embedding", FieldType.nullable(new ArrowType.FixedSizeList(128)), children); + + LogicalType vectorType = LanceTypeConverter.arrowTypeToFlinkType(vectorField); + assertThat(vectorType).isInstanceOf(ArrayType.class); + assertThat(((ArrayType) vectorType).getElementType()).isInstanceOf(FloatType.class); + } + + @Test + @DisplayName("Test Flink type to Arrow type mapping") + void testFlinkTypeToArrowType() { + // TINYINT -> Int8 + Field tinyIntField = LanceTypeConverter.flinkTypeToArrowField("tinyint", new TinyIntType()); + assertThat(tinyIntField.getType()).isInstanceOf(ArrowType.Int.class); + assertThat(((ArrowType.Int) tinyIntField.getType()).getBitWidth()).isEqualTo(8); + + // INT -> Int32 + Field intField = LanceTypeConverter.flinkTypeToArrowField("int", new IntType()); + assertThat(intField.getType()).isInstanceOf(ArrowType.Int.class); + assertThat(((ArrowType.Int) intField.getType()).getBitWidth()).isEqualTo(32); + + // BIGINT -> Int64 + Field bigIntField = LanceTypeConverter.flinkTypeToArrowField("bigint", new BigIntType()); + assertThat(bigIntField.getType()).isInstanceOf(ArrowType.Int.class); + assertThat(((ArrowType.Int) bigIntField.getType()).getBitWidth()).isEqualTo(64); + + // FLOAT -> Float32 + Field floatField = 
LanceTypeConverter.flinkTypeToArrowField("float", new FloatType()); + assertThat(floatField.getType()).isInstanceOf(ArrowType.FloatingPoint.class); + assertThat(((ArrowType.FloatingPoint) floatField.getType()).getPrecision()) + .isEqualTo(FloatingPointPrecision.SINGLE); + + // DOUBLE -> Float64 + Field doubleField = LanceTypeConverter.flinkTypeToArrowField("double", new DoubleType()); + assertThat(doubleField.getType()).isInstanceOf(ArrowType.FloatingPoint.class); + assertThat(((ArrowType.FloatingPoint) doubleField.getType()).getPrecision()) + .isEqualTo(FloatingPointPrecision.DOUBLE); + + // STRING -> Utf8 + Field stringField = LanceTypeConverter.flinkTypeToArrowField("string", new VarCharType()); + assertThat(stringField.getType()).isInstanceOf(ArrowType.Utf8.class); + + // BOOLEAN -> Bool + Field boolField = LanceTypeConverter.flinkTypeToArrowField("bool", new BooleanType()); + assertThat(boolField.getType()).isInstanceOf(ArrowType.Bool.class); + } + + @Test + @DisplayName("Test Arrow Schema to Flink RowType conversion") + void testArrowSchemaToFlinkRowType() { + List fields = new ArrayList<>(); + fields.add(new Field("id", FieldType.notNullable(new ArrowType.Int(64, true)), null)); + fields.add(new Field("name", FieldType.nullable(ArrowType.Utf8.INSTANCE), null)); + fields.add( + new Field( + "score", + FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), + null)); + + Schema arrowSchema = new Schema(fields); + RowType rowType = LanceTypeConverter.toFlinkRowType(arrowSchema); + + assertThat(rowType.getFieldCount()).isEqualTo(3); + assertThat(rowType.getFieldNames()).containsExactly("id", "name", "score"); + assertThat(rowType.getTypeAt(0)).isInstanceOf(BigIntType.class); + assertThat(rowType.getTypeAt(1)).isInstanceOf(VarCharType.class); + assertThat(rowType.getTypeAt(2)).isInstanceOf(DoubleType.class); + } + + @Test + @DisplayName("Test Flink RowType to Arrow Schema conversion") + void testFlinkRowTypeToArrowSchema() { + List fields = new ArrayList<>(); + fields.add(new RowType.RowField("id", new BigIntType(false))); + fields.add(new RowType.RowField("content", new VarCharType())); + fields.add(new RowType.RowField("embedding", new ArrayType(new FloatType()))); + + RowType rowType = new RowType(fields); + Schema arrowSchema = LanceTypeConverter.toArrowSchema(rowType); + + assertThat(arrowSchema.getFields()).hasSize(3); + assertThat(arrowSchema.getFields().get(0).getName()).isEqualTo("id"); + assertThat(arrowSchema.getFields().get(1).getName()).isEqualTo("content"); + assertThat(arrowSchema.getFields().get(2).getName()).isEqualTo("embedding"); + } + + @Test + @DisplayName("Test vector field creation") + void testCreateVectorField() { + // Float32 vector + Field float32Vector = LanceTypeConverter.createVectorField("embedding", 128, false); + assertThat(float32Vector.getName()).isEqualTo("embedding"); + assertThat(float32Vector.getType()).isInstanceOf(ArrowType.FixedSizeList.class); + assertThat(((ArrowType.FixedSizeList) float32Vector.getType()).getListSize()).isEqualTo(128); + assertThat(float32Vector.isNullable()).isFalse(); + + // Float64 vector + Field float64Vector = LanceTypeConverter.createFloat64VectorField("embedding64", 256, true); + assertThat(float64Vector.getName()).isEqualTo("embedding64"); + assertThat(((ArrowType.FixedSizeList) float64Vector.getType()).getListSize()).isEqualTo(256); + assertThat(float64Vector.isNullable()).isTrue(); + } + + @Test + @DisplayName("Test vector field detection") + void testIsVectorField() { + // Create vector field 
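+    // Detection contract exercised by the assertions below: only fields typed as
+    // FixedSizeList are reported as vectors, getVectorDimension returns the list
+    // size (128 for the field created here), and -1 signals "not a vector" for
+    // scalar fields such as the Int64 field further down.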
+ Field vectorField = LanceTypeConverter.createVectorField("embedding", 128, false); + assertThat(LanceTypeConverter.isVectorField(vectorField)).isTrue(); + assertThat(LanceTypeConverter.getVectorDimension(vectorField)).isEqualTo(128); + + // Non-vector field + Field intField = new Field("id", FieldType.notNullable(new ArrowType.Int(64, true)), null); + assertThat(LanceTypeConverter.isVectorField(intField)).isFalse(); + assertThat(LanceTypeConverter.getVectorDimension(intField)).isEqualTo(-1); + } + + @Test + @DisplayName("Test unsupported type exception") + void testUnsupportedTypeException() { + // Unsupported Arrow type + Field unsupportedField = + new Field("unsupported", FieldType.nullable(new ArrowType.Duration(TimeUnit.SECOND)), null); + + assertThatThrownBy(() -> LanceTypeConverter.arrowTypeToFlinkType(unsupportedField)) + .isInstanceOf(LanceTypeConverter.UnsupportedTypeException.class) + .hasMessageContaining("Unsupported Arrow type"); + } + + @Test + @DisplayName("Test round-trip conversion consistency") + void testRoundTripConversion() { + // Create Flink RowType + List fields = new ArrayList<>(); + fields.add(new RowType.RowField("id", new BigIntType(false))); + fields.add(new RowType.RowField("name", new VarCharType())); + fields.add(new RowType.RowField("score", new DoubleType())); + fields.add(new RowType.RowField("active", new BooleanType())); + + RowType originalRowType = new RowType(fields); + + // Flink -> Arrow -> Flink + Schema arrowSchema = LanceTypeConverter.toArrowSchema(originalRowType); + RowType convertedRowType = LanceTypeConverter.toFlinkRowType(arrowSchema); + + // Verify field count and names + assertThat(convertedRowType.getFieldCount()).isEqualTo(originalRowType.getFieldCount()); + assertThat(convertedRowType.getFieldNames()).isEqualTo(originalRowType.getFieldNames()); + + // Verify types (type classes should match) + for (int i = 0; i < originalRowType.getFieldCount(); i++) { + assertThat(convertedRowType.getTypeAt(i).getClass()) + .isEqualTo(originalRowType.getTypeAt(i).getClass()); } + } } diff --git a/src/test/java/org/apache/flink/connector/lance/LanceVectorSearchTest.java b/src/test/java/org/apache/flink/connector/lance/LanceVectorSearchTest.java index 1f59e60..7a0a971 100644 --- a/src/test/java/org/apache/flink/connector/lance/LanceVectorSearchTest.java +++ b/src/test/java/org/apache/flink/connector/lance/LanceVectorSearchTest.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.flink.connector.lance; import org.apache.flink.connector.lance.config.LanceOptions; @@ -31,212 +26,224 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -/** - * LanceVectorSearch unit tests. - */ +/** LanceVectorSearch unit tests. 
*/ class LanceVectorSearchTest { - @TempDir - Path tempDir; - - private String datasetPath; - - @BeforeEach - void setUp() { - datasetPath = tempDir.resolve("test_search_dataset").toString(); - } - - @Test - @DisplayName("Test vector search configuration build") - void testVectorSearchConfiguration() { - LanceVectorSearch search = LanceVectorSearch.builder() - .datasetPath(datasetPath) - .columnName("embedding") - .metricType(MetricType.L2) - .nprobes(20) - .ef(100) - .refineFactor(10) - .build(); - - assertThat(search).isNotNull(); - } - - @Test - @DisplayName("Test different metric types") - void testDifferentMetricTypes() { - // L2 distance - LanceVectorSearch l2Search = LanceVectorSearch.builder() - .datasetPath(datasetPath) - .columnName("embedding") - .metricType(MetricType.L2) - .build(); - assertThat(l2Search).isNotNull(); - - // Cosine similarity - LanceVectorSearch cosineSearch = LanceVectorSearch.builder() - .datasetPath(datasetPath) - .columnName("embedding") - .metricType(MetricType.COSINE) - .build(); - assertThat(cosineSearch).isNotNull(); - - // Dot product - LanceVectorSearch dotSearch = LanceVectorSearch.builder() - .datasetPath(datasetPath) - .columnName("embedding") - .metricType(MetricType.DOT) - .build(); - assertThat(dotSearch).isNotNull(); - } - - @Test - @DisplayName("Test exception when missing dataset path") - void testMissingDatasetPath() { - assertThatThrownBy(() -> LanceVectorSearch.builder() - .columnName("embedding") - .metricType(MetricType.L2) - .build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Dataset path cannot be empty"); - } - - @Test - @DisplayName("Test exception when missing column name") - void testMissingColumnName() { - assertThatThrownBy(() -> LanceVectorSearch.builder() - .datasetPath(datasetPath) - .metricType(MetricType.L2) - .build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Column name cannot be empty"); - } - - @Test - @DisplayName("Test invalid nprobes value") - void testInvalidNprobes() { - assertThatThrownBy(() -> LanceVectorSearch.builder() - .datasetPath(datasetPath) - .columnName("embedding") - .nprobes(0) - .build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("nprobes must be greater than 0"); - } - - @Test - @DisplayName("Test default vector search configuration values") - void testDefaultVectorSearchConfiguration() { - LanceOptions options = LanceOptions.builder() - .path(datasetPath) - .vectorColumn("embedding") - .build(); - - // Verify default values - assertThat(options.getVectorMetric()).isEqualTo(MetricType.L2); - assertThat(options.getVectorNprobes()).isEqualTo(20); - assertThat(options.getVectorEf()).isEqualTo(100); - assertThat(options.getVectorRefineFactor()).isNull(); - } - - @Test - @DisplayName("Test creating vector searcher from LanceOptions") - void testFromOptions() { - LanceOptions options = LanceOptions.builder() - .path(datasetPath) - .vectorColumn("embedding") - .vectorMetric(MetricType.COSINE) - .vectorNprobes(30) - .vectorEf(150) - .vectorRefineFactor(5) - .build(); - - LanceVectorSearch search = LanceVectorSearch.fromOptions(options); - - assertThat(search).isNotNull(); - } - - @Test - @DisplayName("Test search result") - void testSearchResult() { - // Create mock RowData - LanceVectorSearch.SearchResult result = new LanceVectorSearch.SearchResult(null, 0.5); - - assertThat(result.getDistance()).isEqualTo(0.5); - assertThat(result.getSimilarity()).isGreaterThan(0); - 
assertThat(result.getSimilarity()).isLessThanOrEqualTo(1.0); - } - - @Test - @DisplayName("Test search result similarity calculation") - void testSearchResultSimilarity() { - // Distance 0 should have similarity 1.0 - LanceVectorSearch.SearchResult perfectMatch = new LanceVectorSearch.SearchResult(null, 0.0); - assertThat(perfectMatch.getSimilarity()).isEqualTo(1.0); - - // Distance 1 should have similarity 0.5 - LanceVectorSearch.SearchResult halfMatch = new LanceVectorSearch.SearchResult(null, 1.0); - assertThat(halfMatch.getSimilarity()).isEqualTo(0.5); - - // Greater distance means lower similarity - LanceVectorSearch.SearchResult farResult = new LanceVectorSearch.SearchResult(null, 10.0); - assertThat(farResult.getSimilarity()).isLessThan(0.5); - } - - @Test - @DisplayName("Test search result equality") - void testSearchResultEquality() { - LanceVectorSearch.SearchResult result1 = new LanceVectorSearch.SearchResult(null, 0.5); - LanceVectorSearch.SearchResult result2 = new LanceVectorSearch.SearchResult(null, 0.5); - LanceVectorSearch.SearchResult result3 = new LanceVectorSearch.SearchResult(null, 1.0); - - assertThat(result1).isEqualTo(result2); - assertThat(result1.hashCode()).isEqualTo(result2.hashCode()); - assertThat(result1).isNotEqualTo(result3); - } - - @Test - @DisplayName("Test configuration validation - invalid vector search params") - void testInvalidVectorSearchParams() { - assertThatThrownBy(() -> LanceOptions.builder() - .path(datasetPath) - .vectorColumn("embedding") - .vectorNprobes(0) - .build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("nprobes"); - - assertThatThrownBy(() -> LanceOptions.builder() - .path(datasetPath) - .vectorColumn("embedding") - .vectorEf(0) - .build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("ef"); - - assertThatThrownBy(() -> LanceOptions.builder() - .path(datasetPath) - .vectorColumn("embedding") - .vectorRefineFactor(0) - .build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("refine-factor"); - } - - @Test - @DisplayName("Test metric type values") - void testMetricTypeValues() { - assertThat(MetricType.L2.getValue()).isEqualTo("L2"); - assertThat(MetricType.COSINE.getValue()).isEqualTo("Cosine"); - assertThat(MetricType.DOT.getValue()).isEqualTo("Dot"); - } - - @Test - @DisplayName("Test search result toString") - void testSearchResultToString() { - LanceVectorSearch.SearchResult result = new LanceVectorSearch.SearchResult(null, 0.5); - String str = result.toString(); - - assertThat(str).contains("SearchResult"); - assertThat(str).contains("distance=0.5"); - } + @TempDir Path tempDir; + + private String datasetPath; + + @BeforeEach + void setUp() { + datasetPath = tempDir.resolve("test_search_dataset").toString(); + } + + @Test + @DisplayName("Test vector search configuration build") + void testVectorSearchConfiguration() { + LanceVectorSearch search = + LanceVectorSearch.builder() + .datasetPath(datasetPath) + .columnName("embedding") + .metricType(MetricType.L2) + .nprobes(20) + .ef(100) + .refineFactor(10) + .build(); + + assertThat(search).isNotNull(); + } + + @Test + @DisplayName("Test different metric types") + void testDifferentMetricTypes() { + // L2 distance + LanceVectorSearch l2Search = + LanceVectorSearch.builder() + .datasetPath(datasetPath) + .columnName("embedding") + .metricType(MetricType.L2) + .build(); + assertThat(l2Search).isNotNull(); + + // Cosine similarity + LanceVectorSearch cosineSearch = + LanceVectorSearch.builder() + 
.datasetPath(datasetPath) + .columnName("embedding") + .metricType(MetricType.COSINE) + .build(); + assertThat(cosineSearch).isNotNull(); + + // Dot product + LanceVectorSearch dotSearch = + LanceVectorSearch.builder() + .datasetPath(datasetPath) + .columnName("embedding") + .metricType(MetricType.DOT) + .build(); + assertThat(dotSearch).isNotNull(); + } + + @Test + @DisplayName("Test exception when missing dataset path") + void testMissingDatasetPath() { + assertThatThrownBy( + () -> + LanceVectorSearch.builder() + .columnName("embedding") + .metricType(MetricType.L2) + .build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Dataset path cannot be empty"); + } + + @Test + @DisplayName("Test exception when missing column name") + void testMissingColumnName() { + assertThatThrownBy( + () -> + LanceVectorSearch.builder() + .datasetPath(datasetPath) + .metricType(MetricType.L2) + .build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Column name cannot be empty"); + } + + @Test + @DisplayName("Test invalid nprobes value") + void testInvalidNprobes() { + assertThatThrownBy( + () -> + LanceVectorSearch.builder() + .datasetPath(datasetPath) + .columnName("embedding") + .nprobes(0) + .build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("nprobes must be greater than 0"); + } + + @Test + @DisplayName("Test default vector search configuration values") + void testDefaultVectorSearchConfiguration() { + LanceOptions options = + LanceOptions.builder().path(datasetPath).vectorColumn("embedding").build(); + + // Verify default values + assertThat(options.getVectorMetric()).isEqualTo(MetricType.L2); + assertThat(options.getVectorNprobes()).isEqualTo(20); + assertThat(options.getVectorEf()).isEqualTo(100); + assertThat(options.getVectorRefineFactor()).isNull(); + } + + @Test + @DisplayName("Test creating vector searcher from LanceOptions") + void testFromOptions() { + LanceOptions options = + LanceOptions.builder() + .path(datasetPath) + .vectorColumn("embedding") + .vectorMetric(MetricType.COSINE) + .vectorNprobes(30) + .vectorEf(150) + .vectorRefineFactor(5) + .build(); + + LanceVectorSearch search = LanceVectorSearch.fromOptions(options); + + assertThat(search).isNotNull(); + } + + @Test + @DisplayName("Test search result") + void testSearchResult() { + // Create mock RowData + LanceVectorSearch.SearchResult result = new LanceVectorSearch.SearchResult(null, 0.5); + + assertThat(result.getDistance()).isEqualTo(0.5); + assertThat(result.getSimilarity()).isGreaterThan(0); + assertThat(result.getSimilarity()).isLessThanOrEqualTo(1.0); + } + + @Test + @DisplayName("Test search result similarity calculation") + void testSearchResultSimilarity() { + // Distance 0 should have similarity 1.0 + LanceVectorSearch.SearchResult perfectMatch = new LanceVectorSearch.SearchResult(null, 0.0); + assertThat(perfectMatch.getSimilarity()).isEqualTo(1.0); + + // Distance 1 should have similarity 0.5 + LanceVectorSearch.SearchResult halfMatch = new LanceVectorSearch.SearchResult(null, 1.0); + assertThat(halfMatch.getSimilarity()).isEqualTo(0.5); + + // Greater distance means lower similarity + LanceVectorSearch.SearchResult farResult = new LanceVectorSearch.SearchResult(null, 10.0); + assertThat(farResult.getSimilarity()).isLessThan(0.5); + } + + @Test + @DisplayName("Test search result equality") + void testSearchResultEquality() { + LanceVectorSearch.SearchResult result1 = new LanceVectorSearch.SearchResult(null, 0.5); + 
LanceVectorSearch.SearchResult result2 = new LanceVectorSearch.SearchResult(null, 0.5); + LanceVectorSearch.SearchResult result3 = new LanceVectorSearch.SearchResult(null, 1.0); + + assertThat(result1).isEqualTo(result2); + assertThat(result1.hashCode()).isEqualTo(result2.hashCode()); + assertThat(result1).isNotEqualTo(result3); + } + + @Test + @DisplayName("Test configuration validation - invalid vector search params") + void testInvalidVectorSearchParams() { + assertThatThrownBy( + () -> + LanceOptions.builder() + .path(datasetPath) + .vectorColumn("embedding") + .vectorNprobes(0) + .build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("nprobes"); + + assertThatThrownBy( + () -> + LanceOptions.builder() + .path(datasetPath) + .vectorColumn("embedding") + .vectorEf(0) + .build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("ef"); + + assertThatThrownBy( + () -> + LanceOptions.builder() + .path(datasetPath) + .vectorColumn("embedding") + .vectorRefineFactor(0) + .build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("refine-factor"); + } + + @Test + @DisplayName("Test metric type values") + void testMetricTypeValues() { + assertThat(MetricType.L2.getValue()).isEqualTo("L2"); + assertThat(MetricType.COSINE.getValue()).isEqualTo("Cosine"); + assertThat(MetricType.DOT.getValue()).isEqualTo("Dot"); + } + + @Test + @DisplayName("Test search result toString") + void testSearchResultToString() { + LanceVectorSearch.SearchResult result = new LanceVectorSearch.SearchResult(null, 0.5); + String str = result.toString(); + + assertThat(str).contains("SearchResult"); + assertThat(str).contains("distance=0.5"); + } } diff --git a/src/test/java/org/apache/flink/connector/lance/aggregate/AggregateExecutorTest.java b/src/test/java/org/apache/flink/connector/lance/aggregate/AggregateExecutorTest.java index 892392c..9cc6bbb 100644 --- a/src/test/java/org/apache/flink/connector/lance/aggregate/AggregateExecutorTest.java +++ b/src/test/java/org/apache/flink/connector/lance/aggregate/AggregateExecutorTest.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ - package org.apache.flink.connector.lance.aggregate; import org.apache.flink.table.data.GenericRowData; @@ -37,495 +32,470 @@ import static org.junit.jupiter.api.Assertions.*; -/** - * AggregateExecutor unit tests - */ +/** AggregateExecutor unit tests */ @DisplayName("AggregateExecutor Unit Tests") class AggregateExecutorTest { - private RowType sourceRowType; - - @BeforeEach - void setUp() { - // Create test RowType: (id INT, name VARCHAR, category VARCHAR, amount DOUBLE, quantity INT) - sourceRowType = RowType.of( - new IntType(), - new VarCharType(100), - new VarCharType(50), - new DoubleType(), - new IntType() - ); - sourceRowType = new RowType(Arrays.asList( + private RowType sourceRowType; + + @BeforeEach + void setUp() { + // Create test RowType: (id INT, name VARCHAR, category VARCHAR, amount DOUBLE, quantity INT) + sourceRowType = + RowType.of( + new IntType(), + new VarCharType(100), + new VarCharType(50), + new DoubleType(), + new IntType()); + sourceRowType = + new RowType( + Arrays.asList( new RowType.RowField("id", new IntType()), new RowType.RowField("name", new VarCharType(100)), new RowType.RowField("category", new VarCharType(50)), new RowType.RowField("amount", new DoubleType()), - new RowType.RowField("quantity", new IntType()) - )); + new RowType.RowField("quantity", new IntType()))); + } + + /** Create test data row */ + private RowData createRow(int id, String name, String category, double amount, int quantity) { + GenericRowData row = new GenericRowData(5); + row.setField(0, id); + row.setField(1, StringData.fromString(name)); + row.setField(2, StringData.fromString(category)); + row.setField(3, amount); + row.setField(4, quantity); + return row; + } + + // ==================== COUNT Aggregate Tests ==================== + + @Nested + @DisplayName("COUNT Aggregate Tests") + class CountAggregateTests { + + @Test + @DisplayName("COUNT(*) should correctly count all rows") + void testCountStar() { + AggregateInfo aggInfo = AggregateInfo.builder().addCountStar("cnt").build(); + + AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); + executor.init(); + + // Accumulate 5 rows of data + executor.accumulate(createRow(1, "Alice", "A", 100.0, 10)); + executor.accumulate(createRow(2, "Bob", "B", 200.0, 20)); + executor.accumulate(createRow(3, "Charlie", "A", 150.0, 15)); + executor.accumulate(createRow(4, "David", "B", 180.0, 18)); + executor.accumulate(createRow(5, "Eve", "C", 220.0, 22)); + + List results = executor.getResults(); + + assertEquals(1, results.size()); + assertEquals(5L, results.get(0).getLong(0)); // COUNT(*) } - /** - * Create test data row - */ - private RowData createRow(int id, String name, String category, double amount, int quantity) { - GenericRowData row = new GenericRowData(5); - row.setField(0, id); - row.setField(1, StringData.fromString(name)); - row.setField(2, StringData.fromString(category)); - row.setField(3, amount); - row.setField(4, quantity); - return row; + @Test + @DisplayName("COUNT(column) should correctly count non-null values") + void testCountColumn() { + AggregateInfo aggInfo = AggregateInfo.builder().addCount("name", "name_count").build(); + + AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); + executor.init(); + + executor.accumulate(createRow(1, "Alice", "A", 100.0, 10)); + executor.accumulate(createRow(2, "Bob", "B", 200.0, 20)); + executor.accumulate(createRow(3, "Charlie", "A", 150.0, 15)); + + List results = executor.getResults(); + + assertEquals(1, results.size()); + 
assertEquals(3L, results.get(0).getLong(0)); } - // ==================== COUNT Aggregate Tests ==================== - - @Nested - @DisplayName("COUNT Aggregate Tests") - class CountAggregateTests { - - @Test - @DisplayName("COUNT(*) should correctly count all rows") - void testCountStar() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addCountStar("cnt") - .build(); - - AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); - executor.init(); - - // Accumulate 5 rows of data - executor.accumulate(createRow(1, "Alice", "A", 100.0, 10)); - executor.accumulate(createRow(2, "Bob", "B", 200.0, 20)); - executor.accumulate(createRow(3, "Charlie", "A", 150.0, 15)); - executor.accumulate(createRow(4, "David", "B", 180.0, 18)); - executor.accumulate(createRow(5, "Eve", "C", 220.0, 22)); - - List results = executor.getResults(); - - assertEquals(1, results.size()); - assertEquals(5L, results.get(0).getLong(0)); // COUNT(*) - } + @Test + @DisplayName("COUNT(*) on empty dataset should return 0") + void testCountStarEmpty() { + AggregateInfo aggInfo = AggregateInfo.builder().addCountStar("cnt").build(); - @Test - @DisplayName("COUNT(column) should correctly count non-null values") - void testCountColumn() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addCount("name", "name_count") - .build(); - - AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); - executor.init(); - - executor.accumulate(createRow(1, "Alice", "A", 100.0, 10)); - executor.accumulate(createRow(2, "Bob", "B", 200.0, 20)); - executor.accumulate(createRow(3, "Charlie", "A", 150.0, 15)); - - List results = executor.getResults(); - - assertEquals(1, results.size()); - assertEquals(3L, results.get(0).getLong(0)); - } + AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); + executor.init(); - @Test - @DisplayName("COUNT(*) on empty dataset should return 0") - void testCountStarEmpty() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addCountStar("cnt") - .build(); - - AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); - executor.init(); - - List results = executor.getResults(); - - assertEquals(1, results.size()); - assertEquals(0L, results.get(0).getLong(0)); - } + List results = executor.getResults(); + + assertEquals(1, results.size()); + assertEquals(0L, results.get(0).getLong(0)); } + } - // ==================== SUM Aggregate Tests ==================== - - @Nested - @DisplayName("SUM Aggregate Tests") - class SumAggregateTests { - - @Test - @DisplayName("SUM should correctly sum values") - void testSum() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addSum("amount", "total_amount") - .build(); - - AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); - executor.init(); - - executor.accumulate(createRow(1, "Alice", "A", 100.0, 10)); - executor.accumulate(createRow(2, "Bob", "B", 200.0, 20)); - executor.accumulate(createRow(3, "Charlie", "A", 150.0, 15)); - - List results = executor.getResults(); - - assertEquals(1, results.size()); - assertEquals(450.0, results.get(0).getDouble(0), 0.001); - } + // ==================== SUM Aggregate Tests ==================== - @Test - @DisplayName("SUM on empty dataset should return null") - void testSumEmpty() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addSum("amount", "total_amount") - .build(); - - AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); - executor.init(); - - List results = executor.getResults(); - - 
assertEquals(1, results.size()); - assertTrue(results.get(0).isNullAt(0)); - } + @Nested + @DisplayName("SUM Aggregate Tests") + class SumAggregateTests { + + @Test + @DisplayName("SUM should correctly sum values") + void testSum() { + AggregateInfo aggInfo = AggregateInfo.builder().addSum("amount", "total_amount").build(); + + AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); + executor.init(); + + executor.accumulate(createRow(1, "Alice", "A", 100.0, 10)); + executor.accumulate(createRow(2, "Bob", "B", 200.0, 20)); + executor.accumulate(createRow(3, "Charlie", "A", 150.0, 15)); + + List results = executor.getResults(); + + assertEquals(1, results.size()); + assertEquals(450.0, results.get(0).getDouble(0), 0.001); } - // ==================== AVG Aggregate Tests ==================== - - @Nested - @DisplayName("AVG Aggregate Tests") - class AvgAggregateTests { - - @Test - @DisplayName("AVG should correctly calculate average") - void testAvg() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addAvg("amount", "avg_amount") - .build(); - - AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); - executor.init(); - - executor.accumulate(createRow(1, "Alice", "A", 100.0, 10)); - executor.accumulate(createRow(2, "Bob", "B", 200.0, 20)); - executor.accumulate(createRow(3, "Charlie", "A", 150.0, 15)); - - List results = executor.getResults(); - - assertEquals(1, results.size()); - assertEquals(150.0, results.get(0).getDouble(0), 0.001); // (100+200+150)/3 - } + @Test + @DisplayName("SUM on empty dataset should return null") + void testSumEmpty() { + AggregateInfo aggInfo = AggregateInfo.builder().addSum("amount", "total_amount").build(); - @Test - @DisplayName("AVG on empty dataset should return null") - void testAvgEmpty() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addAvg("amount", "avg_amount") - .build(); - - AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); - executor.init(); - - List results = executor.getResults(); - - assertEquals(1, results.size()); - assertTrue(results.get(0).isNullAt(0)); - } + AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); + executor.init(); + + List results = executor.getResults(); + + assertEquals(1, results.size()); + assertTrue(results.get(0).isNullAt(0)); } + } - // ==================== MIN/MAX Aggregate Tests ==================== - - @Nested - @DisplayName("MIN/MAX Aggregate Tests") - class MinMaxAggregateTests { - - @Test - @DisplayName("MIN should return minimum value") - void testMin() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addMin("amount", "min_amount") - .build(); - - AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); - executor.init(); - - executor.accumulate(createRow(1, "Alice", "A", 100.0, 10)); - executor.accumulate(createRow(2, "Bob", "B", 200.0, 20)); - executor.accumulate(createRow(3, "Charlie", "A", 50.0, 15)); - - List results = executor.getResults(); - - assertEquals(1, results.size()); - assertEquals(50.0, results.get(0).getDouble(0), 0.001); - } + // ==================== AVG Aggregate Tests ==================== - @Test - @DisplayName("MAX should return maximum value") - void testMax() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addMax("amount", "max_amount") - .build(); - - AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); - executor.init(); - - executor.accumulate(createRow(1, "Alice", "A", 100.0, 10)); - executor.accumulate(createRow(2, "Bob", "B", 
200.0, 20)); - executor.accumulate(createRow(3, "Charlie", "A", 50.0, 15)); - - List results = executor.getResults(); - - assertEquals(1, results.size()); - assertEquals(200.0, results.get(0).getDouble(0), 0.001); - } + @Nested + @DisplayName("AVG Aggregate Tests") + class AvgAggregateTests { - @Test - @DisplayName("MIN/MAX on empty dataset should return null") - void testMinMaxEmpty() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addMin("amount", "min_amount") - .addMax("amount", "max_amount") - .build(); - - AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); - executor.init(); - - List results = executor.getResults(); - - assertEquals(1, results.size()); - assertTrue(results.get(0).isNullAt(0)); // MIN - assertTrue(results.get(0).isNullAt(1)); // MAX - } + @Test + @DisplayName("AVG should correctly calculate average") + void testAvg() { + AggregateInfo aggInfo = AggregateInfo.builder().addAvg("amount", "avg_amount").build(); + + AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); + executor.init(); + + executor.accumulate(createRow(1, "Alice", "A", 100.0, 10)); + executor.accumulate(createRow(2, "Bob", "B", 200.0, 20)); + executor.accumulate(createRow(3, "Charlie", "A", 150.0, 15)); + + List results = executor.getResults(); + + assertEquals(1, results.size()); + assertEquals(150.0, results.get(0).getDouble(0), 0.001); // (100+200+150)/3 } - // ==================== GROUP BY Tests ==================== - - @Nested - @DisplayName("GROUP BY Aggregate Tests") - class GroupByAggregateTests { - - @Test - @DisplayName("COUNT with GROUP BY should count by group") - void testGroupByCount() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addCountStar("cnt") - .groupBy("category") - .build(); - - AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); - executor.init(); - - executor.accumulate(createRow(1, "Alice", "A", 100.0, 10)); - executor.accumulate(createRow(2, "Bob", "B", 200.0, 20)); - executor.accumulate(createRow(3, "Charlie", "A", 150.0, 15)); - executor.accumulate(createRow(4, "David", "B", 180.0, 18)); - executor.accumulate(createRow(5, "Eve", "A", 220.0, 22)); - - List results = executor.getResults(); - - assertEquals(2, results.size()); // 2 groups: A and B - - // Verify count for each group - long countA = 0, countB = 0; - for (RowData row : results) { - String category = row.getString(0).toString(); - long count = row.getLong(1); - if ("A".equals(category)) { - countA = count; - } else if ("B".equals(category)) { - countB = count; - } - } - assertEquals(3, countA); // A has 3 rows - assertEquals(2, countB); // B has 2 rows - } + @Test + @DisplayName("AVG on empty dataset should return null") + void testAvgEmpty() { + AggregateInfo aggInfo = AggregateInfo.builder().addAvg("amount", "avg_amount").build(); - @Test - @DisplayName("SUM with GROUP BY should sum by group") - void testGroupBySum() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addSum("amount", "total_amount") - .groupBy("category") - .build(); - - AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); - executor.init(); - - executor.accumulate(createRow(1, "Alice", "A", 100.0, 10)); - executor.accumulate(createRow(2, "Bob", "B", 200.0, 20)); - executor.accumulate(createRow(3, "Charlie", "A", 150.0, 15)); - - List results = executor.getResults(); - - assertEquals(2, results.size()); - - // Verify sum for each group - for (RowData row : results) { - String category = row.getString(0).toString(); - double sum = 
row.getDouble(1); - if ("A".equals(category)) { - assertEquals(250.0, sum, 0.001); // 100 + 150 - } else if ("B".equals(category)) { - assertEquals(200.0, sum, 0.001); // 200 - } - } - } + AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); + executor.init(); - @Test - @DisplayName("Empty dataset with GROUP BY should return empty result") - void testGroupByEmpty() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addCountStar("cnt") - .groupBy("category") - .build(); - - AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); - executor.init(); - - List results = executor.getResults(); - - assertTrue(results.isEmpty()); - } + List results = executor.getResults(); + + assertEquals(1, results.size()); + assertTrue(results.get(0).isNullAt(0)); } + } - // ==================== Multiple Aggregates Tests ==================== - - @Nested - @DisplayName("Multiple Aggregates Tests") - class MultipleAggregatesTests { - - @Test - @DisplayName("Multiple aggregate functions should work together") - void testMultipleAggregates() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addCountStar("cnt") - .addSum("amount", "sum_amount") - .addAvg("amount", "avg_amount") - .addMin("amount", "min_amount") - .addMax("amount", "max_amount") - .build(); - - AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); - executor.init(); - - executor.accumulate(createRow(1, "Alice", "A", 100.0, 10)); - executor.accumulate(createRow(2, "Bob", "B", 200.0, 20)); - executor.accumulate(createRow(3, "Charlie", "A", 150.0, 15)); - - List results = executor.getResults(); - - assertEquals(1, results.size()); - RowData result = results.get(0); - - assertEquals(3L, result.getLong(0)); // COUNT(*) - assertEquals(450.0, result.getDouble(1), 0.001); // SUM - assertEquals(150.0, result.getDouble(2), 0.001); // AVG - assertEquals(100.0, result.getDouble(3), 0.001); // MIN - assertEquals(200.0, result.getDouble(4), 0.001); // MAX - } + // ==================== MIN/MAX Aggregate Tests ==================== + + @Nested + @DisplayName("MIN/MAX Aggregate Tests") + class MinMaxAggregateTests { + + @Test + @DisplayName("MIN should return minimum value") + void testMin() { + AggregateInfo aggInfo = AggregateInfo.builder().addMin("amount", "min_amount").build(); + + AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); + executor.init(); + + executor.accumulate(createRow(1, "Alice", "A", 100.0, 10)); + executor.accumulate(createRow(2, "Bob", "B", 200.0, 20)); + executor.accumulate(createRow(3, "Charlie", "A", 50.0, 15)); + + List results = executor.getResults(); + + assertEquals(1, results.size()); + assertEquals(50.0, results.get(0).getDouble(0), 0.001); + } + + @Test + @DisplayName("MAX should return maximum value") + void testMax() { + AggregateInfo aggInfo = AggregateInfo.builder().addMax("amount", "max_amount").build(); - @Test - @DisplayName("Multiple aggregates with GROUP BY should work correctly") - void testMultipleAggregatesWithGroupBy() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addCountStar("cnt") - .addSum("amount", "sum_amount") - .addAvg("amount", "avg_amount") - .groupBy("category") - .build(); - - AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); - executor.init(); - - executor.accumulate(createRow(1, "Alice", "A", 100.0, 10)); - executor.accumulate(createRow(2, "Bob", "B", 200.0, 20)); - executor.accumulate(createRow(3, "Charlie", "A", 200.0, 15)); - - List results = executor.getResults(); - - 
assertEquals(2, results.size()); - - for (RowData row : results) { - String category = row.getString(0).toString(); - long count = row.getLong(1); - double sum = row.getDouble(2); - double avg = row.getDouble(3); - - if ("A".equals(category)) { - assertEquals(2, count); - assertEquals(300.0, sum, 0.001); - assertEquals(150.0, avg, 0.001); - } else if ("B".equals(category)) { - assertEquals(1, count); - assertEquals(200.0, sum, 0.001); - assertEquals(200.0, avg, 0.001); - } - } + AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); + executor.init(); + + executor.accumulate(createRow(1, "Alice", "A", 100.0, 10)); + executor.accumulate(createRow(2, "Bob", "B", 200.0, 20)); + executor.accumulate(createRow(3, "Charlie", "A", 50.0, 15)); + + List results = executor.getResults(); + + assertEquals(1, results.size()); + assertEquals(200.0, results.get(0).getDouble(0), 0.001); + } + + @Test + @DisplayName("MIN/MAX on empty dataset should return null") + void testMinMaxEmpty() { + AggregateInfo aggInfo = + AggregateInfo.builder() + .addMin("amount", "min_amount") + .addMax("amount", "max_amount") + .build(); + + AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); + executor.init(); + + List results = executor.getResults(); + + assertEquals(1, results.size()); + assertTrue(results.get(0).isNullAt(0)); // MIN + assertTrue(results.get(0).isNullAt(1)); // MAX + } + } + + // ==================== GROUP BY Tests ==================== + + @Nested + @DisplayName("GROUP BY Aggregate Tests") + class GroupByAggregateTests { + + @Test + @DisplayName("COUNT with GROUP BY should count by group") + void testGroupByCount() { + AggregateInfo aggInfo = + AggregateInfo.builder().addCountStar("cnt").groupBy("category").build(); + + AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); + executor.init(); + + executor.accumulate(createRow(1, "Alice", "A", 100.0, 10)); + executor.accumulate(createRow(2, "Bob", "B", 200.0, 20)); + executor.accumulate(createRow(3, "Charlie", "A", 150.0, 15)); + executor.accumulate(createRow(4, "David", "B", 180.0, 18)); + executor.accumulate(createRow(5, "Eve", "A", 220.0, 22)); + + List results = executor.getResults(); + + assertEquals(2, results.size()); // 2 groups: A and B + + // Verify count for each group + long countA = 0, countB = 0; + for (RowData row : results) { + String category = row.getString(0).toString(); + long count = row.getLong(1); + if ("A".equals(category)) { + countA = count; + } else if ("B".equals(category)) { + countB = count; } + } + assertEquals(3, countA); // A has 3 rows + assertEquals(2, countB); // B has 2 rows } - // ==================== Reset Tests ==================== - - @Nested - @DisplayName("Reset Tests") - class ResetTests { - - @Test - @DisplayName("reset should clear aggregate state") - void testReset() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addCountStar("cnt") - .build(); - - AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); - executor.init(); - - executor.accumulate(createRow(1, "Alice", "A", 100.0, 10)); - executor.accumulate(createRow(2, "Bob", "B", 200.0, 20)); - - // Reset - executor.reset(); - - // Re-initialize and accumulate new data - executor.init(); - executor.accumulate(createRow(3, "Charlie", "A", 150.0, 15)); - - List results = executor.getResults(); - - assertEquals(1, results.size()); - assertEquals(1L, results.get(0).getLong(0)); // Only 1 row after reset + @Test + @DisplayName("SUM with GROUP BY should sum by group") + void 
testGroupBySum() { + AggregateInfo aggInfo = + AggregateInfo.builder().addSum("amount", "total_amount").groupBy("category").build(); + + AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); + executor.init(); + + executor.accumulate(createRow(1, "Alice", "A", 100.0, 10)); + executor.accumulate(createRow(2, "Bob", "B", 200.0, 20)); + executor.accumulate(createRow(3, "Charlie", "A", 150.0, 15)); + + List results = executor.getResults(); + + assertEquals(2, results.size()); + + // Verify sum for each group + for (RowData row : results) { + String category = row.getString(0).toString(); + double sum = row.getDouble(1); + if ("A".equals(category)) { + assertEquals(250.0, sum, 0.001); // 100 + 150 + } else if ("B".equals(category)) { + assertEquals(200.0, sum, 0.001); // 200 } + } + } + + @Test + @DisplayName("Empty dataset with GROUP BY should return empty result") + void testGroupByEmpty() { + AggregateInfo aggInfo = + AggregateInfo.builder().addCountStar("cnt").groupBy("category").build(); + + AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); + executor.init(); + + List results = executor.getResults(); + + assertTrue(results.isEmpty()); + } + } + + // ==================== Multiple Aggregates Tests ==================== + + @Nested + @DisplayName("Multiple Aggregates Tests") + class MultipleAggregatesTests { + + @Test + @DisplayName("Multiple aggregate functions should work together") + void testMultipleAggregates() { + AggregateInfo aggInfo = + AggregateInfo.builder() + .addCountStar("cnt") + .addSum("amount", "sum_amount") + .addAvg("amount", "avg_amount") + .addMin("amount", "min_amount") + .addMax("amount", "max_amount") + .build(); + + AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); + executor.init(); + + executor.accumulate(createRow(1, "Alice", "A", 100.0, 10)); + executor.accumulate(createRow(2, "Bob", "B", 200.0, 20)); + executor.accumulate(createRow(3, "Charlie", "A", 150.0, 15)); + + List results = executor.getResults(); + + assertEquals(1, results.size()); + RowData result = results.get(0); + + assertEquals(3L, result.getLong(0)); // COUNT(*) + assertEquals(450.0, result.getDouble(1), 0.001); // SUM + assertEquals(150.0, result.getDouble(2), 0.001); // AVG + assertEquals(100.0, result.getDouble(3), 0.001); // MIN + assertEquals(200.0, result.getDouble(4), 0.001); // MAX } - // ==================== Result Type Tests ==================== - - @Nested - @DisplayName("Result Type Tests") - class ResultTypeTests { - - @Test - @DisplayName("buildResultRowType should return correct result type") - void testBuildResultRowType() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addCountStar("cnt") - .addSum("amount", "sum_amount") - .groupBy("category") - .build(); - - AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); - executor.init(); - - RowType resultType = executor.buildResultRowType(); - - assertNotNull(resultType); - assertEquals(3, resultType.getFieldCount()); - - // First field is group column category - assertEquals("category", resultType.getFieldNames().get(0)); - - // Second field is COUNT result - assertEquals("cnt", resultType.getFieldNames().get(1)); - assertTrue(resultType.getTypeAt(1) instanceof BigIntType); - - // Third field is SUM result - assertEquals("sum_amount", resultType.getFieldNames().get(2)); - assertTrue(resultType.getTypeAt(2) instanceof DoubleType); + @Test + @DisplayName("Multiple aggregates with GROUP BY should work correctly") + void 
testMultipleAggregatesWithGroupBy() { + AggregateInfo aggInfo = + AggregateInfo.builder() + .addCountStar("cnt") + .addSum("amount", "sum_amount") + .addAvg("amount", "avg_amount") + .groupBy("category") + .build(); + + AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); + executor.init(); + + executor.accumulate(createRow(1, "Alice", "A", 100.0, 10)); + executor.accumulate(createRow(2, "Bob", "B", 200.0, 20)); + executor.accumulate(createRow(3, "Charlie", "A", 200.0, 15)); + + List results = executor.getResults(); + + assertEquals(2, results.size()); + + for (RowData row : results) { + String category = row.getString(0).toString(); + long count = row.getLong(1); + double sum = row.getDouble(2); + double avg = row.getDouble(3); + + if ("A".equals(category)) { + assertEquals(2, count); + assertEquals(300.0, sum, 0.001); + assertEquals(150.0, avg, 0.001); + } else if ("B".equals(category)) { + assertEquals(1, count); + assertEquals(200.0, sum, 0.001); + assertEquals(200.0, avg, 0.001); } + } + } + } + + // ==================== Reset Tests ==================== + + @Nested + @DisplayName("Reset Tests") + class ResetTests { + + @Test + @DisplayName("reset should clear aggregate state") + void testReset() { + AggregateInfo aggInfo = AggregateInfo.builder().addCountStar("cnt").build(); + + AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); + executor.init(); + + executor.accumulate(createRow(1, "Alice", "A", 100.0, 10)); + executor.accumulate(createRow(2, "Bob", "B", 200.0, 20)); + + // Reset + executor.reset(); + + // Re-initialize and accumulate new data + executor.init(); + executor.accumulate(createRow(3, "Charlie", "A", 150.0, 15)); + + List results = executor.getResults(); + + assertEquals(1, results.size()); + assertEquals(1L, results.get(0).getLong(0)); // Only 1 row after reset + } + } + + // ==================== Result Type Tests ==================== + + @Nested + @DisplayName("Result Type Tests") + class ResultTypeTests { + + @Test + @DisplayName("buildResultRowType should return correct result type") + void testBuildResultRowType() { + AggregateInfo aggInfo = + AggregateInfo.builder() + .addCountStar("cnt") + .addSum("amount", "sum_amount") + .groupBy("category") + .build(); + + AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType); + executor.init(); + + RowType resultType = executor.buildResultRowType(); + + assertNotNull(resultType); + assertEquals(3, resultType.getFieldCount()); + + // First field is group column category + assertEquals("category", resultType.getFieldNames().get(0)); + + // Second field is COUNT result + assertEquals("cnt", resultType.getFieldNames().get(1)); + assertTrue(resultType.getTypeAt(1) instanceof BigIntType); + + // Third field is SUM result + assertEquals("sum_amount", resultType.getFieldNames().get(2)); + assertTrue(resultType.getTypeAt(2) instanceof DoubleType); } + } } diff --git a/src/test/java/org/apache/flink/connector/lance/aggregate/AggregateInfoTest.java b/src/test/java/org/apache/flink/connector/lance/aggregate/AggregateInfoTest.java index 08f3e9b..f5e8615 100644 --- a/src/test/java/org/apache/flink/connector/lance/aggregate/AggregateInfoTest.java +++ b/src/test/java/org/apache/flink/connector/lance/aggregate/AggregateInfoTest.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.flink.connector.lance.aggregate; import org.junit.jupiter.api.DisplayName; @@ -27,330 +22,320 @@ import static org.junit.jupiter.api.Assertions.*; -/** - * AggregateInfo unit tests - */ +/** AggregateInfo unit tests */ @DisplayName("AggregateInfo Unit Tests") class AggregateInfoTest { - // ==================== AggregateCall Tests ==================== - - @Nested - @DisplayName("AggregateCall Tests") - class AggregateCallTests { - - @Test - @DisplayName("COUNT(*) should be correctly identified") - void testCountStar() { - AggregateInfo.AggregateCall call = new AggregateInfo.AggregateCall( - AggregateInfo.AggregateFunction.COUNT, null, "cnt"); - - assertTrue(call.isCountStar()); - assertEquals(AggregateInfo.AggregateFunction.COUNT, call.getFunction()); - assertNull(call.getColumn()); - assertEquals("cnt", call.getAlias()); - assertEquals("COUNT(*)", call.toString()); - } - - @Test - @DisplayName("COUNT(column) should be correctly identified") - void testCountColumn() { - AggregateInfo.AggregateCall call = new AggregateInfo.AggregateCall( - AggregateInfo.AggregateFunction.COUNT, "id", "id_count"); - - assertFalse(call.isCountStar()); - assertEquals(AggregateInfo.AggregateFunction.COUNT, call.getFunction()); - assertEquals("id", call.getColumn()); - assertEquals("id_count", call.getAlias()); - assertEquals("COUNT(id)", call.toString()); - } - - @Test - @DisplayName("SUM aggregate should be correctly built") - void testSumAggregate() { - AggregateInfo.AggregateCall call = new AggregateInfo.AggregateCall( - AggregateInfo.AggregateFunction.SUM, "amount", "total_amount"); - - assertFalse(call.isCountStar()); - assertEquals(AggregateInfo.AggregateFunction.SUM, call.getFunction()); - assertEquals("amount", call.getColumn()); - assertEquals("total_amount", call.getAlias()); - assertEquals("SUM(amount)", call.toString()); - } - - @Test - @DisplayName("AVG aggregate should be correctly built") - void testAvgAggregate() { - AggregateInfo.AggregateCall call = new AggregateInfo.AggregateCall( - AggregateInfo.AggregateFunction.AVG, "score", "avg_score"); - - assertEquals(AggregateInfo.AggregateFunction.AVG, call.getFunction()); - assertEquals("score", call.getColumn()); - assertEquals("avg_score", call.getAlias()); - assertEquals("AVG(score)", call.toString()); - } - - @Test - @DisplayName("MIN aggregate should be correctly built") - void testMinAggregate() { - AggregateInfo.AggregateCall call = new AggregateInfo.AggregateCall( - AggregateInfo.AggregateFunction.MIN, "price", "min_price"); - - assertEquals(AggregateInfo.AggregateFunction.MIN, call.getFunction()); - assertEquals("price", call.getColumn()); - assertEquals("min_price", call.getAlias()); - } - - @Test - @DisplayName("MAX aggregate should be correctly built") - void testMaxAggregate() { - AggregateInfo.AggregateCall call = new AggregateInfo.AggregateCall( - AggregateInfo.AggregateFunction.MAX, "price", "max_price"); - - assertEquals(AggregateInfo.AggregateFunction.MAX, 
call.getFunction()); - assertEquals("price", call.getColumn()); - assertEquals("max_price", call.getAlias()); - } - - @Test - @DisplayName("AggregateCall equals and hashCode should work correctly") - void testAggregateCallEqualsAndHashCode() { - AggregateInfo.AggregateCall call1 = new AggregateInfo.AggregateCall( - AggregateInfo.AggregateFunction.SUM, "amount", "total"); - AggregateInfo.AggregateCall call2 = new AggregateInfo.AggregateCall( - AggregateInfo.AggregateFunction.SUM, "amount", "total"); - AggregateInfo.AggregateCall call3 = new AggregateInfo.AggregateCall( - AggregateInfo.AggregateFunction.SUM, "price", "total"); - - assertEquals(call1, call2); - assertEquals(call1.hashCode(), call2.hashCode()); - assertNotEquals(call1, call3); - } + // ==================== AggregateCall Tests ==================== + + @Nested + @DisplayName("AggregateCall Tests") + class AggregateCallTests { + + @Test + @DisplayName("COUNT(*) should be correctly identified") + void testCountStar() { + AggregateInfo.AggregateCall call = + new AggregateInfo.AggregateCall(AggregateInfo.AggregateFunction.COUNT, null, "cnt"); + + assertTrue(call.isCountStar()); + assertEquals(AggregateInfo.AggregateFunction.COUNT, call.getFunction()); + assertNull(call.getColumn()); + assertEquals("cnt", call.getAlias()); + assertEquals("COUNT(*)", call.toString()); + } + + @Test + @DisplayName("COUNT(column) should be correctly identified") + void testCountColumn() { + AggregateInfo.AggregateCall call = + new AggregateInfo.AggregateCall(AggregateInfo.AggregateFunction.COUNT, "id", "id_count"); + + assertFalse(call.isCountStar()); + assertEquals(AggregateInfo.AggregateFunction.COUNT, call.getFunction()); + assertEquals("id", call.getColumn()); + assertEquals("id_count", call.getAlias()); + assertEquals("COUNT(id)", call.toString()); + } + + @Test + @DisplayName("SUM aggregate should be correctly built") + void testSumAggregate() { + AggregateInfo.AggregateCall call = + new AggregateInfo.AggregateCall( + AggregateInfo.AggregateFunction.SUM, "amount", "total_amount"); + + assertFalse(call.isCountStar()); + assertEquals(AggregateInfo.AggregateFunction.SUM, call.getFunction()); + assertEquals("amount", call.getColumn()); + assertEquals("total_amount", call.getAlias()); + assertEquals("SUM(amount)", call.toString()); + } + + @Test + @DisplayName("AVG aggregate should be correctly built") + void testAvgAggregate() { + AggregateInfo.AggregateCall call = + new AggregateInfo.AggregateCall( + AggregateInfo.AggregateFunction.AVG, "score", "avg_score"); + + assertEquals(AggregateInfo.AggregateFunction.AVG, call.getFunction()); + assertEquals("score", call.getColumn()); + assertEquals("avg_score", call.getAlias()); + assertEquals("AVG(score)", call.toString()); + } + + @Test + @DisplayName("MIN aggregate should be correctly built") + void testMinAggregate() { + AggregateInfo.AggregateCall call = + new AggregateInfo.AggregateCall( + AggregateInfo.AggregateFunction.MIN, "price", "min_price"); + + assertEquals(AggregateInfo.AggregateFunction.MIN, call.getFunction()); + assertEquals("price", call.getColumn()); + assertEquals("min_price", call.getAlias()); + } + + @Test + @DisplayName("MAX aggregate should be correctly built") + void testMaxAggregate() { + AggregateInfo.AggregateCall call = + new AggregateInfo.AggregateCall( + AggregateInfo.AggregateFunction.MAX, "price", "max_price"); + + assertEquals(AggregateInfo.AggregateFunction.MAX, call.getFunction()); + assertEquals("price", call.getColumn()); + assertEquals("max_price", call.getAlias()); + } 
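The AggregateCall cases above pin down the value semantics (function, column, alias); the AggregateExecutor tests earlier in this patch show how such calls are consumed. For orientation, here is a minimal end-to-end sketch of that flow. It assumes the five-column row shape implied by the tests' createRow helper (id INT, name STRING, category STRING, amount DOUBLE, quantity INT) and uses the connector API only as exercised by these tests; it is illustrative, not part of the patch.

import java.util.List;

import org.apache.flink.connector.lance.aggregate.AggregateExecutor;
import org.apache.flink.connector.lance.aggregate.AggregateInfo;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.StringData;
import org.apache.flink.table.types.logical.DoubleType;
import org.apache.flink.table.types.logical.IntType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.VarCharType;

/** Sketch: SUM(amount) GROUP BY category over the row shape assumed above. */
class AggregateFlowSketch {

  public static void main(String[] args) {
    // Assumed source schema; the tests' actual sourceRowType fixture is not shown in this hunk.
    RowType sourceRowType =
        RowType.of(
            new LogicalType[] {
              new IntType(),
              new VarCharType(VarCharType.MAX_LENGTH),
              new VarCharType(VarCharType.MAX_LENGTH),
              new DoubleType(),
              new IntType()
            },
            new String[] {"id", "name", "category", "amount", "quantity"});

    AggregateInfo aggInfo =
        AggregateInfo.builder().addSum("amount", "total_amount").groupBy("category").build();

    AggregateExecutor executor = new AggregateExecutor(aggInfo, sourceRowType);
    executor.init();
    executor.accumulate(
        GenericRowData.of(
            1, StringData.fromString("Alice"), StringData.fromString("A"), 100.0, 10));
    executor.accumulate(
        GenericRowData.of(2, StringData.fromString("Bob"), StringData.fromString("B"), 200.0, 20));

    // One result row per group, laid out as (category, total_amount),
    // matching the accesses in testGroupBySum above.
    List<RowData> results = executor.getResults();
    for (RowData row : results) {
      System.out.println(row.getString(0) + " -> " + row.getDouble(1));
    }
  }
}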
+ + @Test + @DisplayName("AggregateCall equals and hashCode should work correctly") + void testAggregateCallEqualsAndHashCode() { + AggregateInfo.AggregateCall call1 = + new AggregateInfo.AggregateCall(AggregateInfo.AggregateFunction.SUM, "amount", "total"); + AggregateInfo.AggregateCall call2 = + new AggregateInfo.AggregateCall(AggregateInfo.AggregateFunction.SUM, "amount", "total"); + AggregateInfo.AggregateCall call3 = + new AggregateInfo.AggregateCall(AggregateInfo.AggregateFunction.SUM, "price", "total"); + + assertEquals(call1, call2); + assertEquals(call1.hashCode(), call2.hashCode()); + assertNotEquals(call1, call3); + } + } + + // ==================== AggregateInfo Builder Tests ==================== + + @Nested + @DisplayName("AggregateInfo Builder Tests") + class AggregateInfoBuilderTests { + + @Test + @DisplayName("Build simple COUNT(*) query") + void testBuildSimpleCountStar() { + AggregateInfo info = AggregateInfo.builder().addCountStar("cnt").build(); + + assertNotNull(info); + assertEquals(1, info.getAggregateCalls().size()); + assertTrue(info.isSimpleCountStar()); + assertFalse(info.hasGroupBy()); + } + + @Test + @DisplayName("Build aggregate query with GROUP BY") + void testBuildAggregateWithGroupBy() { + AggregateInfo info = + AggregateInfo.builder() + .addSum("amount", "total_amount") + .addAvg("score", "avg_score") + .groupBy("category", "region") + .build(); + + assertNotNull(info); + assertEquals(2, info.getAggregateCalls().size()); + assertTrue(info.hasGroupBy()); + assertEquals(Arrays.asList("category", "region"), info.getGroupByColumns()); + assertFalse(info.isSimpleCountStar()); + } + + @Test + @DisplayName("Build multiple aggregates query") + void testBuildMultipleAggregates() { + AggregateInfo info = + AggregateInfo.builder() + .addCountStar("cnt") + .addSum("amount", "sum_amount") + .addAvg("score", "avg_score") + .addMin("price", "min_price") + .addMax("price", "max_price") + .build(); + + assertNotNull(info); + assertEquals(5, info.getAggregateCalls().size()); + assertFalse(info.hasGroupBy()); + } + + @Test + @DisplayName("Build requires at least one aggregate function") + void testBuildRequiresAtLeastOneAggregate() { + assertThrows( + IllegalArgumentException.class, + () -> { + AggregateInfo.builder().build(); + }); } - // ==================== AggregateInfo Builder Tests ==================== - - @Nested - @DisplayName("AggregateInfo Builder Tests") - class AggregateInfoBuilderTests { - - @Test - @DisplayName("Build simple COUNT(*) query") - void testBuildSimpleCountStar() { - AggregateInfo info = AggregateInfo.builder() - .addCountStar("cnt") - .build(); - - assertNotNull(info); - assertEquals(1, info.getAggregateCalls().size()); - assertTrue(info.isSimpleCountStar()); - assertFalse(info.hasGroupBy()); - } - - @Test - @DisplayName("Build aggregate query with GROUP BY") - void testBuildAggregateWithGroupBy() { - AggregateInfo info = AggregateInfo.builder() - .addSum("amount", "total_amount") - .addAvg("score", "avg_score") - .groupBy("category", "region") - .build(); - - assertNotNull(info); - assertEquals(2, info.getAggregateCalls().size()); - assertTrue(info.hasGroupBy()); - assertEquals(Arrays.asList("category", "region"), info.getGroupByColumns()); - assertFalse(info.isSimpleCountStar()); - } - - @Test - @DisplayName("Build multiple aggregates query") - void testBuildMultipleAggregates() { - AggregateInfo info = AggregateInfo.builder() - .addCountStar("cnt") - .addSum("amount", "sum_amount") - .addAvg("score", "avg_score") - .addMin("price", "min_price") 
- .addMax("price", "max_price") - .build(); - - assertNotNull(info); - assertEquals(5, info.getAggregateCalls().size()); - assertFalse(info.hasGroupBy()); - } - - @Test - @DisplayName("Build requires at least one aggregate function") - void testBuildRequiresAtLeastOneAggregate() { - assertThrows(IllegalArgumentException.class, () -> { - AggregateInfo.builder().build(); - }); - } - - @Test - @DisplayName("addAggregateCall should work correctly") - void testAddAggregateCall() { - AggregateInfo.AggregateCall call = new AggregateInfo.AggregateCall( - AggregateInfo.AggregateFunction.SUM, "amount", "total"); - - AggregateInfo info = AggregateInfo.builder() - .addAggregateCall(call) - .build(); - - assertEquals(1, info.getAggregateCalls().size()); - assertEquals(call, info.getAggregateCalls().get(0)); - } - - @Test - @DisplayName("addCount should work correctly") - void testAddCount() { - AggregateInfo info = AggregateInfo.builder() - .addCount("id", "id_count") - .build(); - - AggregateInfo.AggregateCall call = info.getAggregateCalls().get(0); - assertEquals(AggregateInfo.AggregateFunction.COUNT, call.getFunction()); - assertEquals("id", call.getColumn()); - assertFalse(call.isCountStar()); - } - - @Test - @DisplayName("groupBy(List) should work correctly") - void testGroupByWithList() { - List groupCols = Arrays.asList("col1", "col2", "col3"); - - AggregateInfo info = AggregateInfo.builder() - .addCountStar("cnt") - .groupBy(groupCols) - .build(); - - assertEquals(groupCols, info.getGroupByColumns()); - } - - @Test - @DisplayName("groupByFieldIndices should be correctly set") - void testGroupByFieldIndices() { - int[] indices = {0, 2, 4}; - - AggregateInfo info = AggregateInfo.builder() - .addCountStar("cnt") - .groupBy("col1", "col3", "col5") - .groupByFieldIndices(indices) - .build(); - - assertArrayEquals(indices, info.getGroupByFieldIndices()); - } + @Test + @DisplayName("addAggregateCall should work correctly") + void testAddAggregateCall() { + AggregateInfo.AggregateCall call = + new AggregateInfo.AggregateCall(AggregateInfo.AggregateFunction.SUM, "amount", "total"); + + AggregateInfo info = AggregateInfo.builder().addAggregateCall(call).build(); + + assertEquals(1, info.getAggregateCalls().size()); + assertEquals(call, info.getAggregateCalls().get(0)); + } + + @Test + @DisplayName("addCount should work correctly") + void testAddCount() { + AggregateInfo info = AggregateInfo.builder().addCount("id", "id_count").build(); + + AggregateInfo.AggregateCall call = info.getAggregateCalls().get(0); + assertEquals(AggregateInfo.AggregateFunction.COUNT, call.getFunction()); + assertEquals("id", call.getColumn()); + assertFalse(call.isCountStar()); + } + + @Test + @DisplayName("groupBy(List) should work correctly") + void testGroupByWithList() { + List groupCols = Arrays.asList("col1", "col2", "col3"); + + AggregateInfo info = AggregateInfo.builder().addCountStar("cnt").groupBy(groupCols).build(); + + assertEquals(groupCols, info.getGroupByColumns()); + } + + @Test + @DisplayName("groupByFieldIndices should be correctly set") + void testGroupByFieldIndices() { + int[] indices = {0, 2, 4}; + + AggregateInfo info = + AggregateInfo.builder() + .addCountStar("cnt") + .groupBy("col1", "col3", "col5") + .groupByFieldIndices(indices) + .build(); + + assertArrayEquals(indices, info.getGroupByFieldIndices()); + } + } + + // ==================== AggregateInfo Methods Tests ==================== + + @Nested + @DisplayName("AggregateInfo Methods Tests") + class AggregateInfoMethodTests { + + @Test + 
@DisplayName("getRequiredColumns should return all required columns") + void testGetRequiredColumns() { + AggregateInfo info = + AggregateInfo.builder() + .addSum("amount", "sum_amount") + .addAvg("score", "avg_score") + .groupBy("category", "region") + .build(); + + List required = info.getRequiredColumns(); + + // Should contain group columns and aggregate columns + assertTrue(required.contains("category")); + assertTrue(required.contains("region")); + assertTrue(required.contains("amount")); + assertTrue(required.contains("score")); + } + + @Test + @DisplayName("getRequiredColumns should deduplicate") + void testGetRequiredColumnsDedup() { + AggregateInfo info = + AggregateInfo.builder() + .addSum("amount", "sum_amount") + .addAvg("amount", "avg_amount") // Same column + .groupBy("category") + .build(); + + List required = info.getRequiredColumns(); + + // amount should appear only once + long amountCount = required.stream().filter(c -> c.equals("amount")).count(); + assertEquals(1, amountCount); + } + + @Test + @DisplayName("COUNT(*) does not require column") + void testCountStarNoColumn() { + AggregateInfo info = AggregateInfo.builder().addCountStar("cnt").build(); + + List required = info.getRequiredColumns(); + assertTrue(required.isEmpty()); + } + + @Test + @DisplayName("equals and hashCode should work correctly") + void testEqualsAndHashCode() { + AggregateInfo info1 = + AggregateInfo.builder().addSum("amount", "total").groupBy("category").build(); + + AggregateInfo info2 = + AggregateInfo.builder().addSum("amount", "total").groupBy("category").build(); + + AggregateInfo info3 = + AggregateInfo.builder().addAvg("amount", "avg").groupBy("category").build(); + + assertEquals(info1, info2); + assertEquals(info1.hashCode(), info2.hashCode()); + assertNotEquals(info1, info3); } - // ==================== AggregateInfo Methods Tests ==================== - - @Nested - @DisplayName("AggregateInfo Methods Tests") - class AggregateInfoMethodTests { - - @Test - @DisplayName("getRequiredColumns should return all required columns") - void testGetRequiredColumns() { - AggregateInfo info = AggregateInfo.builder() - .addSum("amount", "sum_amount") - .addAvg("score", "avg_score") - .groupBy("category", "region") - .build(); - - List required = info.getRequiredColumns(); - - // Should contain group columns and aggregate columns - assertTrue(required.contains("category")); - assertTrue(required.contains("region")); - assertTrue(required.contains("amount")); - assertTrue(required.contains("score")); - } - - @Test - @DisplayName("getRequiredColumns should deduplicate") - void testGetRequiredColumnsDedup() { - AggregateInfo info = AggregateInfo.builder() - .addSum("amount", "sum_amount") - .addAvg("amount", "avg_amount") // Same column - .groupBy("category") - .build(); - - List required = info.getRequiredColumns(); - - // amount should appear only once - long amountCount = required.stream().filter(c -> c.equals("amount")).count(); - assertEquals(1, amountCount); - } - - @Test - @DisplayName("COUNT(*) does not require column") - void testCountStarNoColumn() { - AggregateInfo info = AggregateInfo.builder() - .addCountStar("cnt") - .build(); - - List required = info.getRequiredColumns(); - assertTrue(required.isEmpty()); - } - - @Test - @DisplayName("equals and hashCode should work correctly") - void testEqualsAndHashCode() { - AggregateInfo info1 = AggregateInfo.builder() - .addSum("amount", "total") - .groupBy("category") - .build(); - - AggregateInfo info2 = AggregateInfo.builder() - .addSum("amount", 
"total") - .groupBy("category") - .build(); - - AggregateInfo info3 = AggregateInfo.builder() - .addAvg("amount", "avg") - .groupBy("category") - .build(); - - assertEquals(info1, info2); - assertEquals(info1.hashCode(), info2.hashCode()); - assertNotEquals(info1, info3); - } - - @Test - @DisplayName("toString should return meaningful string") - void testToString() { - AggregateInfo info = AggregateInfo.builder() - .addSum("amount", "total") - .groupBy("category") - .build(); - - String str = info.toString(); - - assertTrue(str.contains("AggregateInfo")); - assertTrue(str.contains("SUM(amount)")); - assertTrue(str.contains("groupBy")); - assertTrue(str.contains("category")); - } + @Test + @DisplayName("toString should return meaningful string") + void testToString() { + AggregateInfo info = + AggregateInfo.builder().addSum("amount", "total").groupBy("category").build(); + + String str = info.toString(); + + assertTrue(str.contains("AggregateInfo")); + assertTrue(str.contains("SUM(amount)")); + assertTrue(str.contains("groupBy")); + assertTrue(str.contains("category")); } + } + + // ==================== Aggregate Function Enum Tests ==================== + + @Nested + @DisplayName("AggregateFunction Enum Tests") + class AggregateFunctionEnumTests { + + @Test + @DisplayName("Should contain all supported aggregate functions") + void testAllAggregateFunctions() { + AggregateInfo.AggregateFunction[] functions = AggregateInfo.AggregateFunction.values(); - // ==================== Aggregate Function Enum Tests ==================== - - @Nested - @DisplayName("AggregateFunction Enum Tests") - class AggregateFunctionEnumTests { - - @Test - @DisplayName("Should contain all supported aggregate functions") - void testAllAggregateFunctions() { - AggregateInfo.AggregateFunction[] functions = AggregateInfo.AggregateFunction.values(); - - assertEquals(6, functions.length); - assertTrue(Arrays.asList(functions).contains(AggregateInfo.AggregateFunction.COUNT)); - assertTrue(Arrays.asList(functions).contains(AggregateInfo.AggregateFunction.COUNT_DISTINCT)); - assertTrue(Arrays.asList(functions).contains(AggregateInfo.AggregateFunction.SUM)); - assertTrue(Arrays.asList(functions).contains(AggregateInfo.AggregateFunction.AVG)); - assertTrue(Arrays.asList(functions).contains(AggregateInfo.AggregateFunction.MIN)); - assertTrue(Arrays.asList(functions).contains(AggregateInfo.AggregateFunction.MAX)); - } + assertEquals(6, functions.length); + assertTrue(Arrays.asList(functions).contains(AggregateInfo.AggregateFunction.COUNT)); + assertTrue(Arrays.asList(functions).contains(AggregateInfo.AggregateFunction.COUNT_DISTINCT)); + assertTrue(Arrays.asList(functions).contains(AggregateInfo.AggregateFunction.SUM)); + assertTrue(Arrays.asList(functions).contains(AggregateInfo.AggregateFunction.AVG)); + assertTrue(Arrays.asList(functions).contains(AggregateInfo.AggregateFunction.MIN)); + assertTrue(Arrays.asList(functions).contains(AggregateInfo.AggregateFunction.MAX)); } + } } diff --git a/src/test/java/org/apache/flink/connector/lance/table/FlinkSqlDemo.java b/src/test/java/org/apache/flink/connector/lance/table/FlinkSqlDemo.java index 6d7603a..3974220 100644 --- a/src/test/java/org/apache/flink/connector/lance/table/FlinkSqlDemo.java +++ b/src/test/java/org/apache/flink/connector/lance/table/FlinkSqlDemo.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.flink.connector.lance.table; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -23,9 +18,9 @@ import org.apache.flink.table.api.TableEnvironment; import org.apache.flink.table.api.TableResult; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; - import org.apache.flink.types.Row; import org.apache.flink.util.CloseableIterator; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; @@ -38,761 +33,811 @@ /** * Flink SQL complete demo test script. - * + * *
 * <p>This test demonstrates how to use Flink SQL to operate Lance datasets:
+ *
 * <ul>
- *   <li>Create Lance Catalog</li>
- *   <li>Create Lance tables</li>
- *   <li>Insert vector data</li>
- *   <li>Query data</li>
- *   <li>Build vector index</li>
- *   <li>Execute vector search</li>
+ *   <li>Create Lance Catalog
+ *   <li>Create Lance tables
+ *   <li>Insert vector data
+ *   <li>Query data
+ *   <li>Build vector index
+ *   <li>Execute vector search
 * </ul>
*/ class FlinkSqlDemo { - @TempDir - Path tempDir; - - private TableEnvironment tableEnv; - private String warehousePath; - private String datasetPath; - - @BeforeEach - void setUp() { - // Create Flink Table environment - EnvironmentSettings settings = EnvironmentSettings.newInstance() - .inBatchMode() - .build(); - tableEnv = TableEnvironment.create(settings); - - // Set paths - warehousePath = tempDir.resolve("lance_warehouse").toString(); - datasetPath = tempDir.resolve("lance_dataset").toString(); - } + @TempDir Path tempDir; - @AfterEach - void tearDown() { - // Cleanup resources - if (tableEnv != null) { - // TableEnvironment auto cleanup - } - } + private TableEnvironment tableEnv; + private String warehousePath; + private String datasetPath; - // ==================== Basic SQL Operations ==================== - - @Test - @DisplayName("1. Create Lance Connector Table - Basic Usage") - void testCreateLanceTable() throws Exception { - String createTableSql = String.format( - "CREATE TABLE lance_vectors (\n" + - " id BIGINT,\n" + - " content STRING,\n" + - " embedding ARRAY,\n" + - " category STRING,\n" + - " create_time TIMESTAMP(3)\n" + - ") WITH (\n" + - " 'connector' = 'lance',\n" + - " 'path' = '%s',\n" + - " 'write.batch-size' = '1024',\n" + - " 'write.mode' = 'overwrite'\n" + - ")", datasetPath); - - System.out.println("========== Create Lance Table =========="); - System.out.println(createTableSql); - System.out.println(); - - tableEnv.executeSql(createTableSql); - System.out.println("✅ Table created successfully!\n"); - } + @BeforeEach + void setUp() { + // Create Flink Table environment + EnvironmentSettings settings = EnvironmentSettings.newInstance().inBatchMode().build(); + tableEnv = TableEnvironment.create(settings); - @Test - @DisplayName("2. Insert Vector Data to Lance Table") - void testInsertData() throws Exception { - // Use relative path based on project root - Path path = Paths.get(System.getProperty("user.dir"), "test-data"); - // First create table - String createTableSql = String.format( - "CREATE TABLE lance_documents (\n" + - " id BIGINT,\n" + - " title STRING,\n" + - " embedding ARRAY\n" + - ") WITH (\n" + - " 'connector' = 'lance',\n" + - " 'path' = '%s',\n" + - " 'write.mode' = 'overwrite'\n" + - ")", path.resolve("lance-db1")); - - tableEnv.executeSql(createTableSql); - - // Insert data - String insertSql = - "INSERT INTO lance_documents VALUES\n" + - " (1, 'Introduction to AI', ARRAY[0.1, 0.2, 0.3, 0.4]),\n" + - " (2, 'Machine Learning Guide', ARRAY[0.2, 0.3, 0.4, 0.5]),\n" + - " (3, 'Deep Learning Basics', ARRAY[0.3, 0.4, 0.5, 0.6]),\n" + - " (4, 'Neural Networks', ARRAY[0.4, 0.5, 0.6, 0.7]),\n" + - " (5, 'Computer Vision', ARRAY[0.5, 0.6, 0.7, 0.8])"; - - System.out.println("========== Insert Vector Data =========="); - System.out.println(insertSql); - System.out.println(); - - TableResult result = tableEnv.executeSql(insertSql); - result.await(30, TimeUnit.SECONDS); - System.out.println("✅ Data inserted successfully!\n"); - } + // Set paths + warehousePath = tempDir.resolve("lance_warehouse").toString(); + datasetPath = tempDir.resolve("lance_dataset").toString(); + } - @Test - @DisplayName("3. 
Query Lance Table Data") - void testSelectData() throws Exception { - // Create source table (for generating test data) - String createSourceSql = - "CREATE TABLE test_source (\n" + - " id BIGINT,\n" + - " name STRING\n" + - ") WITH (\n" + - " 'connector' = 'datagen',\n" + - " 'rows-per-second' = '1',\n" + - " 'number-of-rows' = '10',\n" + - " 'fields.id.kind' = 'sequence',\n" + - " 'fields.id.start' = '1',\n" + - " 'fields.id.end' = '10'\n" + - ")"; - - tableEnv.executeSql(createSourceSql); - - // Query data - String selectSql = "SELECT id, name FROM test_source LIMIT 5"; - - System.out.println("========== Query Data =========="); - System.out.println(selectSql); - System.out.println(); - - TableResult result = tableEnv.executeSql(selectSql); - result.print(); - System.out.println("✅ Query completed!\n"); + @AfterEach + void tearDown() { + // Cleanup resources + if (tableEnv != null) { + // TableEnvironment auto cleanup } + } - // ==================== Advanced Configuration ==================== - - @Test - @DisplayName("4. Create Table with Vector Index Configuration") - void testCreateTableWithIndexConfig() throws Exception { - String createTableSql = String.format( - "CREATE TABLE vector_store (\n" + - " id BIGINT,\n" + - " text STRING,\n" + - " embedding ARRAY COMMENT '768-dim vector'\n" + - ") WITH (\n" + - " 'connector' = 'lance',\n" + - " 'path' = '%s',\n" + - " -- Write configuration\n" + - " 'write.batch-size' = '2048',\n" + - " 'write.mode' = 'append',\n" + - " 'write.max-rows-per-file' = '100000',\n" + - " -- Index configuration\n" + - " 'index.type' = 'IVF_PQ',\n" + - " 'index.column' = 'embedding',\n" + - " 'index.num-partitions' = '256',\n" + - " 'index.num-sub-vectors' = '16',\n" + - " -- Vector search configuration\n" + - " 'vector.column' = 'embedding',\n" + - " 'vector.metric' = 'L2',\n" + - " 'vector.nprobes' = '20'\n" + - ")", datasetPath); - - System.out.println("========== Create Table with Index Configuration =========="); - System.out.println(createTableSql); - System.out.println(); - - tableEnv.executeSql(createTableSql); - System.out.println("✅ Table created successfully!\n"); - } + // ==================== Basic SQL Operations ==================== - @Test - @DisplayName("5. 
Different Index Type Configuration Examples") - void testDifferentIndexTypes() { - System.out.println("========== Index Type Configuration Examples ==========\n"); - - // IVF_PQ index (recommended, balances accuracy and speed) - String ivfPqConfig = - "-- IVF_PQ index configuration (recommended for large-scale vector data)\n" + - "'index.type' = 'IVF_PQ',\n" + - "'index.num-partitions' = '256', -- Number of cluster centers\n" + - "'index.num-sub-vectors' = '16', -- Number of sub-vectors\n" + - "'index.num-bits' = '8' -- Quantization bits per sub-vector\n"; - - System.out.println(ivfPqConfig); - - // IVF_HNSW index (high accuracy) - String ivfHnswConfig = - "-- IVF_HNSW index configuration (for high accuracy scenarios)\n" + - "'index.type' = 'IVF_HNSW',\n" + - "'index.num-partitions' = '256',\n" + - "'index.max-level' = '7', -- HNSW max level\n" + - "'index.m' = '16', -- HNSW connections per level\n" + - "'index.ef-construction' = '100' -- ef parameter during construction\n"; - - System.out.println(ivfHnswConfig); - - // IVF_FLAT index (highest accuracy, suitable for small datasets) - String ivfFlatConfig = - "-- IVF_FLAT index configuration (for small-scale datasets)\n" + - "'index.type' = 'IVF_FLAT',\n" + - "'index.num-partitions' = '64' -- Number of cluster centers\n"; - - System.out.println(ivfFlatConfig); - System.out.println("✅ Configuration examples displayed!\n"); - } + @Test + @DisplayName("1. Create Lance Connector Table - Basic Usage") + void testCreateLanceTable() throws Exception { + String createTableSql = + String.format( + "CREATE TABLE lance_vectors (\n" + + " id BIGINT,\n" + + " content STRING,\n" + + " embedding ARRAY,\n" + + " category STRING,\n" + + " create_time TIMESTAMP(3)\n" + + ") WITH (\n" + + " 'connector' = 'lance',\n" + + " 'path' = '%s',\n" + + " 'write.batch-size' = '1024',\n" + + " 'write.mode' = 'overwrite'\n" + + ")", + datasetPath); - @Test - @DisplayName("6. Distance Metric Type Configuration Examples") - void testMetricTypes() { - System.out.println("========== Distance Metric Type Examples ==========\n"); - - String l2Config = - "-- L2 distance (Euclidean distance, default)\n" + - "'vector.metric' = 'L2'\n" + - "-- Suitable for: General vector search\n"; - System.out.println(l2Config); - - String cosineConfig = - "-- Cosine distance (Cosine similarity)\n" + - "'vector.metric' = 'COSINE'\n" + - "-- Suitable for: Text semantic similarity\n"; - System.out.println(cosineConfig); - - String dotConfig = - "-- Dot distance (Dot product)\n" + - "'vector.metric' = 'DOT'\n" + - "-- Suitable for: Already normalized vectors\n"; - System.out.println(dotConfig); - - System.out.println("✅ Configuration examples displayed!\n"); - } + System.out.println("========== Create Lance Table =========="); + System.out.println(createTableSql); + System.out.println(); - // ==================== Catalog Operations ==================== - - @Test - @DisplayName("7. 
Create and Use Lance Catalog") - void testLanceCatalog() throws Exception { - String createCatalogSql = String.format( - "CREATE CATALOG lance_catalog WITH (\n" + - " 'type' = 'lance',\n" + - " 'warehouse' = '%s',\n" + - " 'default-database' = 'default'\n" + - ")", warehousePath); - - System.out.println("========== Create Lance Catalog =========="); - System.out.println(createCatalogSql); - System.out.println(); - - tableEnv.executeSql(createCatalogSql); - - // Use Catalog - tableEnv.executeSql("USE CATALOG lance_catalog"); - System.out.println("✅ Catalog created and switched!\n"); - - // Create database - tableEnv.executeSql("CREATE DATABASE IF NOT EXISTS vector_db"); - System.out.println("✅ Database vector_db created!\n"); - - // List databases - System.out.println("Database list:"); - tableEnv.executeSql("SHOW DATABASES").print(); - } + tableEnv.executeSql(createTableSql); + System.out.println("✅ Table created successfully!\n"); + } - // ==================== Streaming Processing ==================== - - @Test - @DisplayName("8. Streaming Write to Lance Table") - void testStreamingWrite() throws Exception { - // Create streaming environment - StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); - env.setParallelism(1); - StreamTableEnvironment streamTableEnv = StreamTableEnvironment.create(env); - - // Create data generator table (simulating real-time data) - String createSourceSql = - "CREATE TABLE realtime_events (\n" + - " event_id BIGINT,\n" + - " event_type STRING,\n" + - " event_time AS PROCTIME()\n" + - ") WITH (\n" + - " 'connector' = 'datagen',\n" + - " 'rows-per-second' = '10',\n" + - " 'number-of-rows' = '100',\n" + - " 'fields.event_id.kind' = 'sequence',\n" + - " 'fields.event_id.start' = '1',\n" + - " 'fields.event_id.end' = '100',\n" + - " 'fields.event_type.length' = '10'\n" + - ")"; - - // Create Lance Sink table - String createSinkSql = String.format( - "CREATE TABLE lance_events (\n" + - " event_id BIGINT,\n" + - " event_type STRING\n" + - ") WITH (\n" + - " 'connector' = 'lance',\n" + - " 'path' = '%s',\n" + - " 'write.batch-size' = '100',\n" + - " 'write.mode' = 'append'\n" + - ")", datasetPath); - - System.out.println("========== Streaming Write Example =========="); - System.out.println("-- Source table definition"); - System.out.println(createSourceSql); - System.out.println("\n-- Sink table definition"); - System.out.println(createSinkSql); - System.out.println(); - - streamTableEnv.executeSql(createSourceSql); - streamTableEnv.executeSql(createSinkSql); - - // Execute streaming write - String insertSql = "INSERT INTO lance_events SELECT event_id, event_type FROM realtime_events"; - System.out.println("-- Streaming insert statement"); - System.out.println(insertSql); - System.out.println(); - - System.out.println("✅ Streaming write configuration completed!\n"); - } + @Test + @DisplayName("2. Insert Vector Data to Lance Table") + void testInsertData() throws Exception { + // Use relative path based on project root + Path path = Paths.get(System.getProperty("user.dir"), "test-data"); + // First create table + String createTableSql = + String.format( + "CREATE TABLE lance_documents (\n" + + " id BIGINT,\n" + + " title STRING,\n" + + " embedding ARRAY\n" + + ") WITH (\n" + + " 'connector' = 'lance',\n" + + " 'path' = '%s',\n" + + " 'write.mode' = 'overwrite'\n" + + ")", + path.resolve("lance-db1")); - // ==================== Complete Example ==================== - - @Test - @DisplayName("9. 
Complete Vector Storage and Search Example") - void testCompleteVectorExample() throws Exception { - // Use relative path based on project root - Path path = Paths.get(System.getProperty("user.dir"), "test-data"); - System.out.println("========== Complete Vector Storage and Search Example ==========\n"); - - // 1. Create vector table - String createTableSql = String.format( - "-- 1. Create vector storage table\n" + - "CREATE TABLE document_vectors (\n" + - " doc_id BIGINT COMMENT 'Document ID',\n" + - " title STRING COMMENT 'Document title',\n" + - " content STRING COMMENT 'Document content',\n" + - " embedding ARRAY COMMENT 'Document vector (768-dim)',\n" + - " category STRING COMMENT 'Document category',\n" + - " create_time TIMESTAMP(3) COMMENT 'Creation time'\n" + - ") WITH (\n" + - " 'connector' = 'lance',\n" + - " 'path' = '%s',\n" + - " -- Write configuration\n" + - " 'write.batch-size' = '1024',\n" + - " 'write.mode' = 'overwrite',\n" + - " -- Index configuration\n" + - " 'index.type' = 'IVF_PQ',\n" + - " 'index.column' = 'embedding',\n" + - " 'index.num-partitions' = '128',\n" + - " 'index.num-sub-vectors' = '32',\n" + - " -- Vector search configuration\n" + - " 'vector.column' = 'embedding',\n" + - " 'vector.metric' = 'COSINE',\n" + - " 'vector.nprobes' = '10'\n" + - ")", path.resolve("lance-db3")); - - System.out.println(createTableSql); - System.out.println(); - tableEnv.executeSql(createTableSql); - - // 2. Insert test data - String insertSql = - "-- 2. Insert vector data\n" + - "INSERT INTO document_vectors VALUES\n" + - " (1, 'Flink Getting Started Guide', 'Introduction to Apache Flink basics...', \n" + - " ARRAY[0.1, 0.2, 0.3, 0.4], 'tutorial', TIMESTAMP '2024-01-01 10:00:00'),\n" + - " (2, 'Stream Processing in Practice', 'Using Flink to process real-time data streams...', \n" + - " ARRAY[0.2, 0.3, 0.4, 0.5], 'practice', TIMESTAMP '2024-01-02 11:00:00'),\n" + - " (3, 'Vector Database Explained', 'Deep understanding of vector search technology...', \n" + - " ARRAY[0.3, 0.4, 0.5, 0.6], 'database', TIMESTAMP '2024-01-03 12:00:00'),\n" + - " (4, 'Lance Format Introduction', 'Lance is an efficient vector storage format...', \n" + - " ARRAY[0.4, 0.5, 0.6, 0.7], 'format', TIMESTAMP '2024-01-04 13:00:00'),\n" + - " (5, 'SQL Connector Development', 'How to develop Flink SQL connectors...', \n" + - " ARRAY[0.5, 0.6, 0.7, 0.8], 'development', TIMESTAMP '2024-01-05 14:00:00')"; - - System.out.println(insertSql); - System.out.println(); - TableResult result = tableEnv.executeSql(insertSql); - result.await(30, TimeUnit.SECONDS); - - // 3. Query data - String selectSql = - "-- 3. Query vector data\n" + - "SELECT doc_id, title, category, create_time\n" + - "FROM document_vectors\n" + - "WHERE category = 'tutorial'\n" + - "ORDER BY create_time DESC"; - - System.out.println(selectSql); - System.out.println(); - TableResult tableResult = tableEnv.executeSql(selectSql); - tableResult.await(3, TimeUnit.SECONDS); - CloseableIterator collect = tableResult.collect(); - while (collect.hasNext()) { - System.out.println(collect.next()); - } - - // 4. Aggregation query - String aggSql = - "-- 4. 
Count documents by category\n" + - "SELECT category, COUNT(*) as doc_count\n" + - "FROM document_vectors\n" + - "GROUP BY category\n" + - "ORDER BY doc_count DESC"; - - System.out.println(aggSql); - System.out.println(); - tableEnv.executeSql(aggSql).print(); - - System.out.println("✅ Complete example displayed!\n"); - } + tableEnv.executeSql(createTableSql); + + // Insert data + String insertSql = + "INSERT INTO lance_documents VALUES\n" + + " (1, 'Introduction to AI', ARRAY[0.1, 0.2, 0.3, 0.4]),\n" + + " (2, 'Machine Learning Guide', ARRAY[0.2, 0.3, 0.4, 0.5]),\n" + + " (3, 'Deep Learning Basics', ARRAY[0.3, 0.4, 0.5, 0.6]),\n" + + " (4, 'Neural Networks', ARRAY[0.4, 0.5, 0.6, 0.7]),\n" + + " (5, 'Computer Vision', ARRAY[0.5, 0.6, 0.7, 0.8])"; + + System.out.println("========== Insert Vector Data =========="); + System.out.println(insertSql); + System.out.println(); + + TableResult result = tableEnv.executeSql(insertSql); + result.await(30, TimeUnit.SECONDS); + System.out.println("✅ Data inserted successfully!\n"); + } + + @Test + @DisplayName("3. Query Lance Table Data") + void testSelectData() throws Exception { + // Create source table (for generating test data) + String createSourceSql = + "CREATE TABLE test_source (\n" + + " id BIGINT,\n" + + " name STRING\n" + + ") WITH (\n" + + " 'connector' = 'datagen',\n" + + " 'rows-per-second' = '1',\n" + + " 'number-of-rows' = '10',\n" + + " 'fields.id.kind' = 'sequence',\n" + + " 'fields.id.start' = '1',\n" + + " 'fields.id.end' = '10'\n" + + ")"; + + tableEnv.executeSql(createSourceSql); + + // Query data + String selectSql = "SELECT id, name FROM test_source LIMIT 5"; + + System.out.println("========== Query Data =========="); + System.out.println(selectSql); + System.out.println(); + + TableResult result = tableEnv.executeSql(selectSql); + result.print(); + System.out.println("✅ Query completed!\n"); + } + + // ==================== Advanced Configuration ==================== + + @Test + @DisplayName("4. Create Table with Vector Index Configuration") + void testCreateTableWithIndexConfig() throws Exception { + String createTableSql = + String.format( + "CREATE TABLE vector_store (\n" + + " id BIGINT,\n" + + " text STRING,\n" + + " embedding ARRAY COMMENT '768-dim vector'\n" + + ") WITH (\n" + + " 'connector' = 'lance',\n" + + " 'path' = '%s',\n" + + " -- Write configuration\n" + + " 'write.batch-size' = '2048',\n" + + " 'write.mode' = 'append',\n" + + " 'write.max-rows-per-file' = '100000',\n" + + " -- Index configuration\n" + + " 'index.type' = 'IVF_PQ',\n" + + " 'index.column' = 'embedding',\n" + + " 'index.num-partitions' = '256',\n" + + " 'index.num-sub-vectors' = '16',\n" + + " -- Vector search configuration\n" + + " 'vector.column' = 'embedding',\n" + + " 'vector.metric' = 'L2',\n" + + " 'vector.nprobes' = '20'\n" + + ")", + datasetPath); + + System.out.println("========== Create Table with Index Configuration =========="); + System.out.println(createTableSql); + System.out.println(); + + tableEnv.executeSql(createTableSql); + System.out.println("✅ Table created successfully!\n"); + } + + @Test + @DisplayName("5. 
Different Index Type Configuration Examples") + void testDifferentIndexTypes() { + System.out.println("========== Index Type Configuration Examples ==========\n"); + + // IVF_PQ index (recommended, balances accuracy and speed) + String ivfPqConfig = + "-- IVF_PQ index configuration (recommended for large-scale vector data)\n" + + "'index.type' = 'IVF_PQ',\n" + + "'index.num-partitions' = '256', -- Number of cluster centers\n" + + "'index.num-sub-vectors' = '16', -- Number of sub-vectors\n" + + "'index.num-bits' = '8' -- Quantization bits per sub-vector\n"; + + System.out.println(ivfPqConfig); + + // IVF_HNSW index (high accuracy) + String ivfHnswConfig = + "-- IVF_HNSW index configuration (for high accuracy scenarios)\n" + + "'index.type' = 'IVF_HNSW',\n" + + "'index.num-partitions' = '256',\n" + + "'index.max-level' = '7', -- HNSW max level\n" + + "'index.m' = '16', -- HNSW connections per level\n" + + "'index.ef-construction' = '100' -- ef parameter during construction\n"; + + System.out.println(ivfHnswConfig); + + // IVF_FLAT index (highest accuracy, suitable for small datasets) + String ivfFlatConfig = + "-- IVF_FLAT index configuration (for small-scale datasets)\n" + + "'index.type' = 'IVF_FLAT',\n" + + "'index.num-partitions' = '64' -- Number of cluster centers\n"; + + System.out.println(ivfFlatConfig); + System.out.println("✅ Configuration examples displayed!\n"); + } + + @Test + @DisplayName("6. Distance Metric Type Configuration Examples") + void testMetricTypes() { + System.out.println("========== Distance Metric Type Examples ==========\n"); + + String l2Config = + "-- L2 distance (Euclidean distance, default)\n" + + "'vector.metric' = 'L2'\n" + + "-- Suitable for: General vector search\n"; + System.out.println(l2Config); + + String cosineConfig = + "-- Cosine distance (Cosine similarity)\n" + + "'vector.metric' = 'COSINE'\n" + + "-- Suitable for: Text semantic similarity\n"; + System.out.println(cosineConfig); + + String dotConfig = + "-- Dot distance (Dot product)\n" + + "'vector.metric' = 'DOT'\n" + + "-- Suitable for: Already normalized vectors\n"; + System.out.println(dotConfig); + + System.out.println("✅ Configuration examples displayed!\n"); + } + + // ==================== Catalog Operations ==================== + + @Test + @DisplayName("7. Create and Use Lance Catalog") + void testLanceCatalog() throws Exception { + String createCatalogSql = + String.format( + "CREATE CATALOG lance_catalog WITH (\n" + + " 'type' = 'lance',\n" + + " 'warehouse' = '%s',\n" + + " 'default-database' = 'default'\n" + + ")", + warehousePath); + + System.out.println("========== Create Lance Catalog =========="); + System.out.println(createCatalogSql); + System.out.println(); + + tableEnv.executeSql(createCatalogSql); + + // Use Catalog + tableEnv.executeSql("USE CATALOG lance_catalog"); + System.out.println("✅ Catalog created and switched!\n"); + + // Create database + tableEnv.executeSql("CREATE DATABASE IF NOT EXISTS vector_db"); + System.out.println("✅ Database vector_db created!\n"); + + // List databases + System.out.println("Database list:"); + tableEnv.executeSql("SHOW DATABASES").print(); + } + + // ==================== Streaming Processing ==================== + + @Test + @DisplayName("8. 
Streaming Write to Lance Table") + void testStreamingWrite() throws Exception { + // Create streaming environment + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(1); + StreamTableEnvironment streamTableEnv = StreamTableEnvironment.create(env); + + // Create data generator table (simulating real-time data) + String createSourceSql = + "CREATE TABLE realtime_events (\n" + + " event_id BIGINT,\n" + + " event_type STRING,\n" + + " event_time AS PROCTIME()\n" + + ") WITH (\n" + + " 'connector' = 'datagen',\n" + + " 'rows-per-second' = '10',\n" + + " 'number-of-rows' = '100',\n" + + " 'fields.event_id.kind' = 'sequence',\n" + + " 'fields.event_id.start' = '1',\n" + + " 'fields.event_id.end' = '100',\n" + + " 'fields.event_type.length' = '10'\n" + + ")"; + + // Create Lance Sink table + String createSinkSql = + String.format( + "CREATE TABLE lance_events (\n" + + " event_id BIGINT,\n" + + " event_type STRING\n" + + ") WITH (\n" + + " 'connector' = 'lance',\n" + + " 'path' = '%s',\n" + + " 'write.batch-size' = '100',\n" + + " 'write.mode' = 'append'\n" + + ")", + datasetPath); + + System.out.println("========== Streaming Write Example =========="); + System.out.println("-- Source table definition"); + System.out.println(createSourceSql); + System.out.println("\n-- Sink table definition"); + System.out.println(createSinkSql); + System.out.println(); + + streamTableEnv.executeSql(createSourceSql); + streamTableEnv.executeSql(createSinkSql); + + // Execute streaming write + String insertSql = "INSERT INTO lance_events SELECT event_id, event_type FROM realtime_events"; + System.out.println("-- Streaming insert statement"); + System.out.println(insertSql); + System.out.println(); - @Test - @DisplayName("9.1 Vector Search IVF_PQ Index Example") - void testVectorSearchWithIvfPq() throws Exception { - System.out.println("========== Vector Search IVF_PQ Index Example =========="); - - // Use relative path based on project root - Path basePath = Paths.get(System.getProperty("user.dir"), "test-data"); - String datasetPath = basePath.resolve("lance-vector-search").toString(); - - // ============================================ - // Step 1: Create vector table with IVF_PQ index configuration - // ============================================ - String createTableSql = String.format( - "CREATE TABLE vector_documents (\n" + - " id BIGINT,\n" + - " title STRING,\n" + - " embedding ARRAY\n" + - ") WITH (\n" + - " 'connector' = 'lance',\n" + - " 'path' = '%s',\n" + - " 'write.batch-size' = '1024',\n" + - " 'write.mode' = 'overwrite',\n" + - " -- IVF_PQ index configuration\n" + - " 'index.type' = 'IVF_PQ',\n" + - " 'index.column' = 'embedding',\n" + - " 'index.num-partitions' = '16',\n" + - " 'index.num-sub-vectors' = '8',\n" + - " -- Vector search configuration\n" + - " 'vector.column' = 'embedding',\n" + - " 'vector.metric' = 'L2',\n" + - " 'vector.nprobes' = '10'\n" + - ")", datasetPath); - - System.out.println("-- Step 1: Create vector table with IVF_PQ index configuration"); - System.out.println(createTableSql); - System.out.println(); - tableEnv.executeSql(createTableSql); - - // ============================================ - // Step 2: Insert vector data - // ============================================ - String insertSql = - "INSERT INTO vector_documents VALUES\n" + - " (1, 'Flink Stream Processing', ARRAY[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]),\n" + - " (2, 'Spark Batch Processing', ARRAY[0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),\n" + - " (3, 
'Kafka Message Queue', ARRAY[0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]),\n" + - " (4, 'Vector Database', ARRAY[0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85]),\n" + - " (5, 'Machine Learning Basics', ARRAY[0.12, 0.22, 0.32, 0.42, 0.52, 0.62, 0.72, 0.82])"; - - System.out.println("-- Step 2: Insert vector data"); - System.out.println(insertSql); - System.out.println(); - tableEnv.executeSql(insertSql).await(30, TimeUnit.SECONDS); - System.out.println("✅ Data insertion completed\n"); - - // ============================================ - // Step 3: Register vector search UDF - // ============================================ - String createFunctionSql = - "CREATE TEMPORARY FUNCTION vector_search AS \n" + - " 'org.apache.flink.connector.lance.table.LanceVectorSearchFunction'"; - - System.out.println("-- Step 3: Register vector search UDF"); - System.out.println(createFunctionSql); - System.out.println(); - tableEnv.executeSql(createFunctionSql); - System.out.println("✅ UDF registration completed\n"); - - // ============================================ - // Step 4: Execute vector search - Basic usage - // ============================================ - System.out.println("-- Step 4: Execute vector search (Basic usage)"); - System.out.println("-- Parameter description:"); - System.out.println("-- Param 1: Dataset path"); - System.out.println("-- Param 2: Vector column name"); - System.out.println("-- Param 3: Query vector"); - System.out.println("-- Param 4: TopK count to return"); - System.out.println("-- Param 5: Distance metric type (L2/COSINE/DOT)"); - System.out.println(); - - String vectorSearchSql = String.format( - "SELECT * FROM TABLE(\n" + - " vector_search(\n" + - " '%s', -- Dataset path\n" + - " 'embedding', -- Vector column name\n" + - " ARRAY[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], -- Query vector\n" + - " 3, -- Return Top 3\n" + - " 'L2' -- L2 distance metric\n" + - " )\n" + - ")", datasetPath); - - System.out.println(vectorSearchSql); - System.out.println(); - System.out.println("📊 Search results (sorted by L2 distance, smaller distance = more similar):"); - System.out.println("---------------------------------------------------"); - - try { - TableResult result = tableEnv.executeSql(vectorSearchSql); - result.print(); - } catch (Exception e) { - System.out.println("⚠️ Vector search execution error: " + e.getMessage()); - System.out.println(" This may be because the dataset needs to build index first"); - } - - // ============================================ - // Step 5: Use COSINE cosine similarity search - // ============================================ - System.out.println("\n-- Step 5: Use COSINE cosine similarity search"); - - String cosineSearchSql = String.format( - "SELECT * FROM TABLE(\n" + - " vector_search(\n" + - " '%s',\n" + - " 'embedding',\n" + - " ARRAY[0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1],\n" + - " 3,\n" + - " 'COSINE' -- Cosine similarity\n" + - " )\n" + - ")", datasetPath); - - System.out.println(cosineSearchSql); - System.out.println(); - System.out.println("📊 Search results (sorted by cosine distance):"); - System.out.println("---------------------------------------------------"); - - try { - tableEnv.executeSql(cosineSearchSql).print(); - } catch (Exception e) { - System.out.println("⚠️ Execution error: " + e.getMessage()); - } - - // ============================================ - // Step 6: Combine vector search with other queries - // ============================================ - System.out.println("\n-- Step 6: Combine vector search with other queries 
(LATERAL TABLE)"); - - String lateralSearchSql = String.format( - "-- First query data, then perform vector search based on results\n" + - "SELECT \n" + - " v.id,\n" + - " v.title,\n" + - " v._distance as similarity_distance\n" + - "FROM TABLE(\n" + - " vector_search('%s', 'embedding', ARRAY[0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85], 5, 'L2')\n" + - ") AS v\n" + - "WHERE v._distance < 1.0 -- Only return results with distance less than 1", datasetPath); - - System.out.println(lateralSearchSql); - System.out.println(); - - // ============================================ - // Print configuration parameter descriptions - // ============================================ - System.out.println("\n========== IVF_PQ Index Configuration Parameter Description =========="); - System.out.println("╔═════════════════════════════╦════════════════════════════════════════════════════╗"); - System.out.println("║ Configuration ║ Description ║"); - System.out.println("╠═════════════════════════════╬════════════════════════════════════════════════════╣"); - System.out.println("║ index.type = 'IVF_PQ' ║ Use IVF_PQ index type ║"); - System.out.println("║ index.column ║ Vector column name to build index on ║"); - System.out.println("║ index.num-partitions ║ IVF partition count, recommend: sqrt(n) to 4*sqrt(n)║"); - System.out.println("║ index.num-sub-vectors ║ PQ sub-vector count, must divide vector dimension ║"); - System.out.println("║ index.num-bits ║ PQ encoding bits, default 8 (256 cluster centers) ║"); - System.out.println("║ vector.metric ║ Distance metric: L2(Euclidean)/COSINE/DOT(Dot product)║"); - System.out.println("║ vector.nprobes ║ Number of partitions to probe during search ║"); - System.out.println("╚═════════════════════════════╩════════════════════════════════════════════════════╝"); - - System.out.println("\n========== Distance Metric Type Description =========="); - System.out.println("╔════════════════╦════════════════════════════════════════════════════════════════╗"); - System.out.println("║ Metric Type ║ Description ║"); - System.out.println("╠════════════════╬════════════════════════════════════════════════════════════════╣"); - System.out.println("║ L2 ║ Euclidean distance, smaller = more similar, for dense vectors ║"); - System.out.println("║ COSINE ║ Cosine distance, range [0,2], smaller = more similar, for text║"); - System.out.println("║ DOT ║ Negative dot product, smaller = more similar (needs normalization)║"); - System.out.println("╚════════════════╩════════════════════════════════════════════════════════════════╝"); - - System.out.println("\n✅ Vector search IVF_PQ example completed!\n"); + System.out.println("✅ Streaming write configuration completed!\n"); + } + + // ==================== Complete Example ==================== + + @Test + @DisplayName("9. Complete Vector Storage and Search Example") + void testCompleteVectorExample() throws Exception { + // Use relative path based on project root + Path path = Paths.get(System.getProperty("user.dir"), "test-data"); + System.out.println("========== Complete Vector Storage and Search Example ==========\n"); + + // 1. Create vector table + String createTableSql = + String.format( + "-- 1. 
Create vector storage table\n" + + "CREATE TABLE document_vectors (\n" + + " doc_id BIGINT COMMENT 'Document ID',\n" + + " title STRING COMMENT 'Document title',\n" + + " content STRING COMMENT 'Document content',\n" + + " embedding ARRAY COMMENT 'Document vector (768-dim)',\n" + + " category STRING COMMENT 'Document category',\n" + + " create_time TIMESTAMP(3) COMMENT 'Creation time'\n" + + ") WITH (\n" + + " 'connector' = 'lance',\n" + + " 'path' = '%s',\n" + + " -- Write configuration\n" + + " 'write.batch-size' = '1024',\n" + + " 'write.mode' = 'overwrite',\n" + + " -- Index configuration\n" + + " 'index.type' = 'IVF_PQ',\n" + + " 'index.column' = 'embedding',\n" + + " 'index.num-partitions' = '128',\n" + + " 'index.num-sub-vectors' = '32',\n" + + " -- Vector search configuration\n" + + " 'vector.column' = 'embedding',\n" + + " 'vector.metric' = 'COSINE',\n" + + " 'vector.nprobes' = '10'\n" + + ")", + path.resolve("lance-db3")); + + System.out.println(createTableSql); + System.out.println(); + tableEnv.executeSql(createTableSql); + + // 2. Insert test data + String insertSql = + "-- 2. Insert vector data\n" + + "INSERT INTO document_vectors VALUES\n" + + " (1, 'Flink Getting Started Guide', 'Introduction to Apache Flink basics...', \n" + + " ARRAY[0.1, 0.2, 0.3, 0.4], 'tutorial', TIMESTAMP '2024-01-01 10:00:00'),\n" + + " (2, 'Stream Processing in Practice', 'Using Flink to process real-time data streams...', \n" + + " ARRAY[0.2, 0.3, 0.4, 0.5], 'practice', TIMESTAMP '2024-01-02 11:00:00'),\n" + + " (3, 'Vector Database Explained', 'Deep understanding of vector search technology...', \n" + + " ARRAY[0.3, 0.4, 0.5, 0.6], 'database', TIMESTAMP '2024-01-03 12:00:00'),\n" + + " (4, 'Lance Format Introduction', 'Lance is an efficient vector storage format...', \n" + + " ARRAY[0.4, 0.5, 0.6, 0.7], 'format', TIMESTAMP '2024-01-04 13:00:00'),\n" + + " (5, 'SQL Connector Development', 'How to develop Flink SQL connectors...', \n" + + " ARRAY[0.5, 0.6, 0.7, 0.8], 'development', TIMESTAMP '2024-01-05 14:00:00')"; + + System.out.println(insertSql); + System.out.println(); + TableResult result = tableEnv.executeSql(insertSql); + result.await(30, TimeUnit.SECONDS); + + // 3. Query data + String selectSql = + "-- 3. Query vector data\n" + + "SELECT doc_id, title, category, create_time\n" + + "FROM document_vectors\n" + + "WHERE category = 'tutorial'\n" + + "ORDER BY create_time DESC"; + + System.out.println(selectSql); + System.out.println(); + TableResult tableResult = tableEnv.executeSql(selectSql); + tableResult.await(3, TimeUnit.SECONDS); + CloseableIterator collect = tableResult.collect(); + while (collect.hasNext()) { + System.out.println(collect.next()); } - @Test - @DisplayName("9.2 Different Index Types Comparison Example") - void testDifferentIndexTypesDetailed() throws Exception { - System.out.println("========== Different Vector Index Types Comparison =========="); - - // Use relative path based on project root - Path basePath = Paths.get(System.getProperty("user.dir"), "test-data"); - - // ============================================ - // IVF_PQ Index - For large-scale data, low memory footprint - // ============================================ - System.out.println("【1. 
IVF_PQ Index】- Recommended for large-scale data"); - System.out.println("Pros: Low memory footprint, fast search speed"); - System.out.println("Cons: Lower accuracy (quantization loss)"); - System.out.println(); - - String ivfPqSql = String.format( - "CREATE TABLE ivf_pq_vectors (\n" + - " id BIGINT,\n" + - " embedding ARRAY\n" + - ") WITH (\n" + - " 'connector' = 'lance',\n" + - " 'path' = '%s',\n" + - " 'index.type' = 'IVF_PQ',\n" + - " 'index.column' = 'embedding',\n" + - " 'index.num-partitions' = '256', -- IVF partition count\n" + - " 'index.num-sub-vectors' = '16', -- PQ sub-vector count\n" + - " 'index.num-bits' = '8', -- Encoding bits per sub-vector\n" + - " 'vector.metric' = 'L2'\n" + - ")", basePath.resolve("ivf-pq-demo")); - - System.out.println(ivfPqSql); - System.out.println(); - - // ============================================ - // IVF_HNSW Index - High accuracy search - // ============================================ - System.out.println("【2. IVF_HNSW Index】- Recommended for high accuracy requirements"); - System.out.println("Pros: High search accuracy"); - System.out.println("Cons: Higher memory footprint, slower index building"); - System.out.println(); - - String ivfHnswSql = String.format( - "CREATE TABLE ivf_hnsw_vectors (\n" + - " id BIGINT,\n" + - " embedding ARRAY\n" + - ") WITH (\n" + - " 'connector' = 'lance',\n" + - " 'path' = '%s',\n" + - " 'index.type' = 'IVF_HNSW',\n" + - " 'index.column' = 'embedding',\n" + - " 'index.num-partitions' = '256', -- IVF partition count\n" + - " 'index.hnsw-m' = '16', -- HNSW connections per level\n" + - " 'index.hnsw-ef-construction' = '100', -- Candidate set size during construction\n" + - " 'vector.metric' = 'COSINE',\n" + - " 'vector.ef' = '50' -- Candidate set size during search\n" + - ")", basePath.resolve("ivf-hnsw-demo")); - - System.out.println(ivfHnswSql); - System.out.println(); - - // ============================================ - // IVF_FLAT Index - Highest accuracy, brute force search - // ============================================ - System.out.println("【3. 
IVF_FLAT Index】- Highest accuracy"); - System.out.println("Pros: 100% search accuracy (lossless)"); - System.out.println("Cons: Slower search speed, suitable for small datasets"); - System.out.println(); - - String ivfFlatSql = String.format( - "CREATE TABLE ivf_flat_vectors (\n" + - " id BIGINT,\n" + - " embedding ARRAY\n" + - ") WITH (\n" + - " 'connector' = 'lance',\n" + - " 'path' = '%s',\n" + - " 'index.type' = 'IVF_FLAT',\n" + - " 'index.column' = 'embedding',\n" + - " 'index.num-partitions' = '128', -- IVF partition count\n" + - " 'vector.metric' = 'DOT',\n" + - " 'vector.nprobes' = '32' -- Number of partitions to probe during search\n" + - ")", basePath.resolve("ivf-flat-demo")); - - System.out.println(ivfFlatSql); - System.out.println(); - - // ============================================ - // Index Selection Recommendations - // ============================================ - System.out.println("========== Index Selection Recommendations =========="); - System.out.println("╔═══════════════════╦════════════════╦═══════════════╦════════════════════════════════╗"); - System.out.println("║ Index Type ║ Data Scale ║ Accuracy ║ Use Case ║"); - System.out.println("╠═══════════════════╬════════════════╬═══════════════╬════════════════════════════════╣"); - System.out.println("║ IVF_PQ ║ 1M+ ║ Medium ║ Large-scale recommendation, image search║"); - System.out.println("║ IVF_HNSW ║ 100K-1M ║ High ║ Semantic search, Q&A systems ║"); - System.out.println("║ IVF_FLAT ║ <100K ║ Highest ║ Small-scale high-precision scenarios║"); - System.out.println("╚═══════════════════╩════════════════╩═══════════════╩════════════════════════════════╝"); - - System.out.println("\n✅ Index type comparison example completed!\n"); + // 4. Aggregation query + String aggSql = + "-- 4. 
Count documents by category\n" + + "SELECT category, COUNT(*) as doc_count\n" + + "FROM document_vectors\n" + + "GROUP BY category\n" + + "ORDER BY doc_count DESC"; + + System.out.println(aggSql); + System.out.println(); + tableEnv.executeSql(aggSql).print(); + + System.out.println("✅ Complete example displayed!\n"); + } + + @Test + @DisplayName("9.1 Vector Search IVF_PQ Index Example") + void testVectorSearchWithIvfPq() throws Exception { + System.out.println("========== Vector Search IVF_PQ Index Example =========="); + + // Use relative path based on project root + Path basePath = Paths.get(System.getProperty("user.dir"), "test-data"); + String datasetPath = basePath.resolve("lance-vector-search").toString(); + + // ============================================ + // Step 1: Create vector table with IVF_PQ index configuration + // ============================================ + String createTableSql = + String.format( + "CREATE TABLE vector_documents (\n" + + " id BIGINT,\n" + + " title STRING,\n" + + " embedding ARRAY\n" + + ") WITH (\n" + + " 'connector' = 'lance',\n" + + " 'path' = '%s',\n" + + " 'write.batch-size' = '1024',\n" + + " 'write.mode' = 'overwrite',\n" + + " -- IVF_PQ index configuration\n" + + " 'index.type' = 'IVF_PQ',\n" + + " 'index.column' = 'embedding',\n" + + " 'index.num-partitions' = '16',\n" + + " 'index.num-sub-vectors' = '8',\n" + + " -- Vector search configuration\n" + + " 'vector.column' = 'embedding',\n" + + " 'vector.metric' = 'L2',\n" + + " 'vector.nprobes' = '10'\n" + + ")", + datasetPath); + + System.out.println("-- Step 1: Create vector table with IVF_PQ index configuration"); + System.out.println(createTableSql); + System.out.println(); + tableEnv.executeSql(createTableSql); + + // ============================================ + // Step 2: Insert vector data + // ============================================ + String insertSql = + "INSERT INTO vector_documents VALUES\n" + + " (1, 'Flink Stream Processing', ARRAY[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]),\n" + + " (2, 'Spark Batch Processing', ARRAY[0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),\n" + + " (3, 'Kafka Message Queue', ARRAY[0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]),\n" + + " (4, 'Vector Database', ARRAY[0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85]),\n" + + " (5, 'Machine Learning Basics', ARRAY[0.12, 0.22, 0.32, 0.42, 0.52, 0.62, 0.72, 0.82])"; + + System.out.println("-- Step 2: Insert vector data"); + System.out.println(insertSql); + System.out.println(); + tableEnv.executeSql(insertSql).await(30, TimeUnit.SECONDS); + System.out.println("✅ Data insertion completed\n"); + + // ============================================ + // Step 3: Register vector search UDF + // ============================================ + String createFunctionSql = + "CREATE TEMPORARY FUNCTION vector_search AS \n" + + " 'org.apache.flink.connector.lance.table.LanceVectorSearchFunction'"; + + System.out.println("-- Step 3: Register vector search UDF"); + System.out.println(createFunctionSql); + System.out.println(); + tableEnv.executeSql(createFunctionSql); + System.out.println("✅ UDF registration completed\n"); + + // ============================================ + // Step 4: Execute vector search - Basic usage + // ============================================ + System.out.println("-- Step 4: Execute vector search (Basic usage)"); + System.out.println("-- Parameter description:"); + System.out.println("-- Param 1: Dataset path"); + System.out.println("-- Param 2: Vector column name"); + System.out.println("-- Param 3: Query 
vector"); + System.out.println("-- Param 4: TopK count to return"); + System.out.println("-- Param 5: Distance metric type (L2/COSINE/DOT)"); + System.out.println(); + + String vectorSearchSql = + String.format( + "SELECT * FROM TABLE(\n" + + " vector_search(\n" + + " '%s', -- Dataset path\n" + + " 'embedding', -- Vector column name\n" + + " ARRAY[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], -- Query vector\n" + + " 3, -- Return Top 3\n" + + " 'L2' -- L2 distance metric\n" + + " )\n" + + ")", + datasetPath); + + System.out.println(vectorSearchSql); + System.out.println(); + System.out.println( + "📊 Search results (sorted by L2 distance, smaller distance = more similar):"); + System.out.println("---------------------------------------------------"); + + try { + TableResult result = tableEnv.executeSql(vectorSearchSql); + result.print(); + } catch (Exception e) { + System.out.println("⚠️ Vector search execution error: " + e.getMessage()); + System.out.println(" This may be because the dataset needs to build index first"); } - @Test - @DisplayName("10. SQL Syntax Quick Reference") - void testSqlQuickReference() { - System.out.println("========================================"); - System.out.println(" Flink SQL Lance Connector Quick Reference"); - System.out.println("========================================\n"); - - System.out.println("【Create Table】"); - System.out.println("CREATE TABLE table_name ("); - System.out.println(" column_name data_type,"); - System.out.println(" embedding ARRAY"); - System.out.println(") WITH ("); - System.out.println(" 'connector' = 'lance',"); - System.out.println(" 'path' = '/path/to/dataset'"); - System.out.println(");\n"); - - System.out.println("【Insert Data】"); - System.out.println("INSERT INTO table_name VALUES (1, 'text', ARRAY[0.1, 0.2, 0.3]);\n"); - - System.out.println("【Query Data】"); - System.out.println("SELECT * FROM table_name WHERE condition;\n"); - - System.out.println("【Create Catalog】"); - System.out.println("CREATE CATALOG lance_catalog WITH ("); - System.out.println(" 'type' = 'lance',"); - System.out.println(" 'warehouse' = '/path/to/warehouse'"); - System.out.println(");\n"); - - System.out.println("【Data Type Mapping】"); - System.out.println("╔════════════════════╦═══════════════════╗"); - System.out.println("║ Flink SQL Type ║ Lance Type ║"); - System.out.println("╠════════════════════╬═══════════════════╣"); - System.out.println("║ BOOLEAN ║ Bool ║"); - System.out.println("║ TINYINT ║ Int8 ║"); - System.out.println("║ SMALLINT ║ Int16 ║"); - System.out.println("║ INT ║ Int32 ║"); - System.out.println("║ BIGINT ║ Int64 ║"); - System.out.println("║ FLOAT ║ Float32 ║"); - System.out.println("║ DOUBLE ║ Float64 ║"); - System.out.println("║ STRING ║ Utf8 ║"); - System.out.println("║ BYTES ║ Binary ║"); - System.out.println("║ DATE ║ Date32 ║"); - System.out.println("║ TIMESTAMP ║ Timestamp ║"); - System.out.println("║ ARRAY ║ FixedSizeList ║"); - System.out.println("╚════════════════════╩═══════════════════╝\n"); - - System.out.println("【Configuration Options】"); - System.out.println("╔═══════════════════════════╦════════════════════════════════╗"); - System.out.println("║ Option ║ Description ║"); - System.out.println("╠═══════════════════════════╬════════════════════════════════╣"); - System.out.println("║ path ║ Dataset path (required) ║"); - System.out.println("║ write.batch-size ║ Write batch size (default 1024)║"); - System.out.println("║ write.mode ║ Write mode: append/overwrite ║"); - System.out.println("║ read.batch-size ║ Read batch 
size (default 1024) ║"); - System.out.println("║ index.type ║ Index type: IVF_PQ/IVF_HNSW/IVF_FLAT║"); - System.out.println("║ index.column ║ Index column name ║"); - System.out.println("║ index.num-partitions ║ IVF partitions (default 256) ║"); - System.out.println("║ vector.column ║ Vector column name ║"); - System.out.println("║ vector.metric ║ Distance metric: L2/COSINE/DOT ║"); - System.out.println("║ vector.nprobes ║ Search probes (default 20) ║"); - System.out.println("╚═══════════════════════════╩════════════════════════════════╝\n"); - - System.out.println("✅ Quick reference completed!"); + // ============================================ + // Step 5: Use COSINE cosine similarity search + // ============================================ + System.out.println("\n-- Step 5: Use COSINE cosine similarity search"); + + String cosineSearchSql = + String.format( + "SELECT * FROM TABLE(\n" + + " vector_search(\n" + + " '%s',\n" + + " 'embedding',\n" + + " ARRAY[0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1],\n" + + " 3,\n" + + " 'COSINE' -- Cosine similarity\n" + + " )\n" + + ")", + datasetPath); + + System.out.println(cosineSearchSql); + System.out.println(); + System.out.println("📊 Search results (sorted by cosine distance):"); + System.out.println("---------------------------------------------------"); + + try { + tableEnv.executeSql(cosineSearchSql).print(); + } catch (Exception e) { + System.out.println("⚠️ Execution error: " + e.getMessage()); } + + // ============================================ + // Step 6: Combine vector search with other queries + // ============================================ + System.out.println("\n-- Step 6: Combine vector search with other queries (LATERAL TABLE)"); + + String lateralSearchSql = + String.format( + "-- First query data, then perform vector search based on results\n" + + "SELECT \n" + + " v.id,\n" + + " v.title,\n" + + " v._distance as similarity_distance\n" + + "FROM TABLE(\n" + + " vector_search('%s', 'embedding', ARRAY[0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85], 5, 'L2')\n" + + ") AS v\n" + + "WHERE v._distance < 1.0 -- Only return results with distance less than 1", + datasetPath); + + System.out.println(lateralSearchSql); + System.out.println(); + + // ============================================ + // Print configuration parameter descriptions + // ============================================ + System.out.println("\n========== IVF_PQ Index Configuration Parameter Description =========="); + System.out.println( + "╔═════════════════════════════╦════════════════════════════════════════════════════╗"); + System.out.println( + "║ Configuration ║ Description ║"); + System.out.println( + "╠═════════════════════════════╬════════════════════════════════════════════════════╣"); + System.out.println( + "║ index.type = 'IVF_PQ' ║ Use IVF_PQ index type ║"); + System.out.println( + "║ index.column ║ Vector column name to build index on ║"); + System.out.println( + "║ index.num-partitions ║ IVF partition count, recommend: sqrt(n) to 4*sqrt(n)║"); + System.out.println( + "║ index.num-sub-vectors ║ PQ sub-vector count, must divide vector dimension ║"); + System.out.println( + "║ index.num-bits ║ PQ encoding bits, default 8 (256 cluster centers) ║"); + System.out.println( + "║ vector.metric ║ Distance metric: L2(Euclidean)/COSINE/DOT(Dot product)║"); + System.out.println( + "║ vector.nprobes ║ Number of partitions to probe during search ║"); + System.out.println( + "╚═════════════════════════════╩════════════════════════════════════════════════════╝"); + + 
System.out.println("\n========== Distance Metric Type Description =========="); + System.out.println( + "╔════════════════╦════════════════════════════════════════════════════════════════╗"); + System.out.println( + "║ Metric Type ║ Description ║"); + System.out.println( + "╠════════════════╬════════════════════════════════════════════════════════════════╣"); + System.out.println( + "║ L2 ║ Euclidean distance, smaller = more similar, for dense vectors ║"); + System.out.println( + "║ COSINE ║ Cosine distance, range [0,2], smaller = more similar, for text║"); + System.out.println( + "║ DOT ║ Negative dot product, smaller = more similar (needs normalization)║"); + System.out.println( + "╚════════════════╩════════════════════════════════════════════════════════════════╝"); + + System.out.println("\n✅ Vector search IVF_PQ example completed!\n"); + } + + @Test + @DisplayName("9.2 Different Index Types Comparison Example") + void testDifferentIndexTypesDetailed() throws Exception { + System.out.println("========== Different Vector Index Types Comparison =========="); + + // Use relative path based on project root + Path basePath = Paths.get(System.getProperty("user.dir"), "test-data"); + + // ============================================ + // IVF_PQ Index - For large-scale data, low memory footprint + // ============================================ + System.out.println("【1. IVF_PQ Index】- Recommended for large-scale data"); + System.out.println("Pros: Low memory footprint, fast search speed"); + System.out.println("Cons: Lower accuracy (quantization loss)"); + System.out.println(); + + String ivfPqSql = + String.format( + "CREATE TABLE ivf_pq_vectors (\n" + + " id BIGINT,\n" + + " embedding ARRAY\n" + + ") WITH (\n" + + " 'connector' = 'lance',\n" + + " 'path' = '%s',\n" + + " 'index.type' = 'IVF_PQ',\n" + + " 'index.column' = 'embedding',\n" + + " 'index.num-partitions' = '256', -- IVF partition count\n" + + " 'index.num-sub-vectors' = '16', -- PQ sub-vector count\n" + + " 'index.num-bits' = '8', -- Encoding bits per sub-vector\n" + + " 'vector.metric' = 'L2'\n" + + ")", + basePath.resolve("ivf-pq-demo")); + + System.out.println(ivfPqSql); + System.out.println(); + + // ============================================ + // IVF_HNSW Index - High accuracy search + // ============================================ + System.out.println("【2. IVF_HNSW Index】- Recommended for high accuracy requirements"); + System.out.println("Pros: High search accuracy"); + System.out.println("Cons: Higher memory footprint, slower index building"); + System.out.println(); + + String ivfHnswSql = + String.format( + "CREATE TABLE ivf_hnsw_vectors (\n" + + " id BIGINT,\n" + + " embedding ARRAY\n" + + ") WITH (\n" + + " 'connector' = 'lance',\n" + + " 'path' = '%s',\n" + + " 'index.type' = 'IVF_HNSW',\n" + + " 'index.column' = 'embedding',\n" + + " 'index.num-partitions' = '256', -- IVF partition count\n" + + " 'index.hnsw-m' = '16', -- HNSW connections per level\n" + + " 'index.hnsw-ef-construction' = '100', -- Candidate set size during construction\n" + + " 'vector.metric' = 'COSINE',\n" + + " 'vector.ef' = '50' -- Candidate set size during search\n" + + ")", + basePath.resolve("ivf-hnsw-demo")); + + System.out.println(ivfHnswSql); + System.out.println(); + + // ============================================ + // IVF_FLAT Index - Highest accuracy, brute force search + // ============================================ + System.out.println("【3. 
IVF_FLAT Index】- Highest accuracy"); + System.out.println("Pros: 100% search accuracy (lossless)"); + System.out.println("Cons: Slower search speed, suitable for small datasets"); + System.out.println(); + + String ivfFlatSql = + String.format( + "CREATE TABLE ivf_flat_vectors (\n" + + " id BIGINT,\n" + + " embedding ARRAY\n" + + ") WITH (\n" + + " 'connector' = 'lance',\n" + + " 'path' = '%s',\n" + + " 'index.type' = 'IVF_FLAT',\n" + + " 'index.column' = 'embedding',\n" + + " 'index.num-partitions' = '128', -- IVF partition count\n" + + " 'vector.metric' = 'DOT',\n" + + " 'vector.nprobes' = '32' -- Number of partitions to probe during search\n" + + ")", + basePath.resolve("ivf-flat-demo")); + + System.out.println(ivfFlatSql); + System.out.println(); + + // ============================================ + // Index Selection Recommendations + // ============================================ + System.out.println("========== Index Selection Recommendations =========="); + System.out.println( + "╔═══════════════════╦════════════════╦═══════════════╦════════════════════════════════╗"); + System.out.println( + "║ Index Type ║ Data Scale ║ Accuracy ║ Use Case ║"); + System.out.println( + "╠═══════════════════╬════════════════╬═══════════════╬════════════════════════════════╣"); + System.out.println( + "║ IVF_PQ ║ 1M+ ║ Medium ║ Large-scale recommendation, image search║"); + System.out.println( + "║ IVF_HNSW ║ 100K-1M ║ High ║ Semantic search, Q&A systems ║"); + System.out.println( + "║ IVF_FLAT ║ <100K ║ Highest ║ Small-scale high-precision scenarios║"); + System.out.println( + "╚═══════════════════╩════════════════╩═══════════════╩════════════════════════════════╝"); + + System.out.println("\n✅ Index type comparison example completed!\n"); + } + + @Test + @DisplayName("10. 
SQL Syntax Quick Reference") + void testSqlQuickReference() { + System.out.println("========================================"); + System.out.println(" Flink SQL Lance Connector Quick Reference"); + System.out.println("========================================\n"); + + System.out.println("【Create Table】"); + System.out.println("CREATE TABLE table_name ("); + System.out.println(" column_name data_type,"); + System.out.println(" embedding ARRAY"); + System.out.println(") WITH ("); + System.out.println(" 'connector' = 'lance',"); + System.out.println(" 'path' = '/path/to/dataset'"); + System.out.println(");\n"); + + System.out.println("【Insert Data】"); + System.out.println("INSERT INTO table_name VALUES (1, 'text', ARRAY[0.1, 0.2, 0.3]);\n"); + + System.out.println("【Query Data】"); + System.out.println("SELECT * FROM table_name WHERE condition;\n"); + + System.out.println("【Create Catalog】"); + System.out.println("CREATE CATALOG lance_catalog WITH ("); + System.out.println(" 'type' = 'lance',"); + System.out.println(" 'warehouse' = '/path/to/warehouse'"); + System.out.println(");\n"); + + System.out.println("【Data Type Mapping】"); + System.out.println("╔════════════════════╦═══════════════════╗"); + System.out.println("║ Flink SQL Type ║ Lance Type ║"); + System.out.println("╠════════════════════╬═══════════════════╣"); + System.out.println("║ BOOLEAN ║ Bool ║"); + System.out.println("║ TINYINT ║ Int8 ║"); + System.out.println("║ SMALLINT ║ Int16 ║"); + System.out.println("║ INT ║ Int32 ║"); + System.out.println("║ BIGINT ║ Int64 ║"); + System.out.println("║ FLOAT ║ Float32 ║"); + System.out.println("║ DOUBLE ║ Float64 ║"); + System.out.println("║ STRING ║ Utf8 ║"); + System.out.println("║ BYTES ║ Binary ║"); + System.out.println("║ DATE ║ Date32 ║"); + System.out.println("║ TIMESTAMP ║ Timestamp ║"); + System.out.println("║ ARRAY ║ FixedSizeList ║"); + System.out.println("╚════════════════════╩═══════════════════╝\n"); + + System.out.println("【Configuration Options】"); + System.out.println("╔═══════════════════════════╦════════════════════════════════╗"); + System.out.println("║ Option ║ Description ║"); + System.out.println("╠═══════════════════════════╬════════════════════════════════╣"); + System.out.println("║ path ║ Dataset path (required) ║"); + System.out.println("║ write.batch-size ║ Write batch size (default 1024)║"); + System.out.println("║ write.mode ║ Write mode: append/overwrite ║"); + System.out.println("║ read.batch-size ║ Read batch size (default 1024) ║"); + System.out.println("║ index.type ║ Index type: IVF_PQ/IVF_HNSW/IVF_FLAT║"); + System.out.println("║ index.column ║ Index column name ║"); + System.out.println("║ index.num-partitions ║ IVF partitions (default 256) ║"); + System.out.println("║ vector.column ║ Vector column name ║"); + System.out.println("║ vector.metric ║ Distance metric: L2/COSINE/DOT ║"); + System.out.println("║ vector.nprobes ║ Search probes (default 20) ║"); + System.out.println("╚═══════════════════════════╩════════════════════════════════╝\n"); + + System.out.println("✅ Quick reference completed!"); + } } diff --git a/src/test/java/org/apache/flink/connector/lance/table/LanceAggregatePushDownTest.java b/src/test/java/org/apache/flink/connector/lance/table/LanceAggregatePushDownTest.java index 56de5ba..7d4df3b 100644 --- a/src/test/java/org/apache/flink/connector/lance/table/LanceAggregatePushDownTest.java +++ b/src/test/java/org/apache/flink/connector/lance/table/LanceAggregatePushDownTest.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software 
Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.flink.connector.lance.table; import org.apache.flink.connector.lance.aggregate.AggregateInfo; @@ -32,326 +27,314 @@ import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import static org.junit.jupiter.api.Assertions.*; -/** - * LanceDynamicTableSource aggregate push-down tests - */ +/** LanceDynamicTableSource aggregate push-down tests */ @DisplayName("LanceDynamicTableSource Aggregate Push-Down Tests") class LanceAggregatePushDownTest { - private LanceOptions options; - private DataType physicalDataType; + private LanceOptions options; + private DataType physicalDataType; - @BeforeEach - void setUp() { - options = LanceOptions.builder() - .path("/tmp/test_lance_dataset") - .readBatchSize(1024) - .build(); + @BeforeEach + void setUp() { + options = LanceOptions.builder().path("/tmp/test_lance_dataset").readBatchSize(1024).build(); - // Create test physical data type - // Schema: (id INT, name VARCHAR, category VARCHAR, amount DOUBLE, quantity INT) - RowType rowType = new RowType(Arrays.asList( + // Create test physical data type + // Schema: (id INT, name VARCHAR, category VARCHAR, amount DOUBLE, quantity INT) + RowType rowType = + new RowType( + Arrays.asList( new RowType.RowField("id", new IntType()), new RowType.RowField("name", new VarCharType(100)), new RowType.RowField("category", new VarCharType(50)), new RowType.RowField("amount", new DoubleType()), - new RowType.RowField("quantity", new IntType()) - )); - physicalDataType = TypeConversions.fromLogicalToDataType(rowType); + new RowType.RowField("quantity", new IntType()))); + physicalDataType = TypeConversions.fromLogicalToDataType(rowType); + } + + // ==================== Aggregate Push-Down Interface Tests ==================== + + @Nested + @DisplayName("applyAggregates Method Tests") + class ApplyAggregatesTests { + + // Note: Since applyAggregates requires real AggregateExpression objects, + // we mainly test aggregate info storage and state management here + + @Test + @DisplayName("Initial state should have no aggregate push-down") + void testInitialState() { + LanceDynamicTableSource source = new LanceDynamicTableSource(options, physicalDataType); + + assertFalse(source.isAggregatePushDownAccepted()); + assertNull(source.getAggregateInfo()); + } + + @Test + @DisplayName("copy should correctly copy aggregate state") + void testCopyAggregateState() { + LanceDynamicTableSource source = new LanceDynamicTableSource(options, physicalDataType); + + // Copy source + LanceDynamicTableSource copied = (LanceDynamicTableSource) source.copy(); + + // Verify copied state + assertFalse(copied.isAggregatePushDownAccepted()); + assertNull(copied.getAggregateInfo()); + 
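// the copy must be a distinct instance, not a reference back to the original source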
assertNotSame(source, copied); + } + + @Test + @DisplayName("asSummaryString should return correct summary") + void testAsSummaryString() { + LanceDynamicTableSource source = new LanceDynamicTableSource(options, physicalDataType); + + String summary = source.asSummaryString(); + + assertEquals("Lance Table Source", summary); + } + } + + // ==================== AggregateInfo Integration Tests ==================== + + @Nested + @DisplayName("AggregateInfo Integration Tests") + class AggregateInfoIntegrationTests { + + @Test + @DisplayName("Simple COUNT(*) aggregate info build") + void testSimpleCountStarAggregateInfo() { + AggregateInfo aggInfo = AggregateInfo.builder().addCountStar("cnt").build(); + + assertTrue(aggInfo.isSimpleCountStar()); + assertFalse(aggInfo.hasGroupBy()); + assertEquals(1, aggInfo.getAggregateCalls().size()); + } + + @Test + @DisplayName("Aggregate info build with GROUP BY") + void testGroupByAggregateInfo() { + AggregateInfo aggInfo = + AggregateInfo.builder() + .addSum("amount", "total_amount") + .addAvg("amount", "avg_amount") + .groupBy("category") + .groupByFieldIndices(new int[] {2}) // category at index 2 + .build(); + + assertFalse(aggInfo.isSimpleCountStar()); + assertTrue(aggInfo.hasGroupBy()); + assertEquals(2, aggInfo.getAggregateCalls().size()); + assertEquals(Collections.singletonList("category"), aggInfo.getGroupByColumns()); } - // ==================== Aggregate Push-Down Interface Tests ==================== - - @Nested - @DisplayName("applyAggregates Method Tests") - class ApplyAggregatesTests { - - // Note: Since applyAggregates requires real AggregateExpression objects, - // we mainly test aggregate info storage and state management here - - @Test - @DisplayName("Initial state should have no aggregate push-down") - void testInitialState() { - LanceDynamicTableSource source = new LanceDynamicTableSource(options, physicalDataType); - - assertFalse(source.isAggregatePushDownAccepted()); - assertNull(source.getAggregateInfo()); - } - - @Test - @DisplayName("copy should correctly copy aggregate state") - void testCopyAggregateState() { - LanceDynamicTableSource source = new LanceDynamicTableSource(options, physicalDataType); - - // Copy source - LanceDynamicTableSource copied = (LanceDynamicTableSource) source.copy(); - - // Verify copied state - assertFalse(copied.isAggregatePushDownAccepted()); - assertNull(copied.getAggregateInfo()); - assertNotSame(source, copied); - } - - @Test - @DisplayName("asSummaryString should return correct summary") - void testAsSummaryString() { - LanceDynamicTableSource source = new LanceDynamicTableSource(options, physicalDataType); - - String summary = source.asSummaryString(); - - assertEquals("Lance Table Source", summary); - } + @Test + @DisplayName("Multiple aggregates info build") + void testMultipleAggregatesInfo() { + AggregateInfo aggInfo = + AggregateInfo.builder() + .addCountStar("cnt") + .addSum("amount", "sum_amount") + .addAvg("amount", "avg_amount") + .addMin("amount", "min_amount") + .addMax("amount", "max_amount") + .build(); + + assertEquals(5, aggInfo.getAggregateCalls().size()); + + // Verify each aggregate function type + List calls = aggInfo.getAggregateCalls(); + assertEquals(AggregateInfo.AggregateFunction.COUNT, calls.get(0).getFunction()); + assertEquals(AggregateInfo.AggregateFunction.SUM, calls.get(1).getFunction()); + assertEquals(AggregateInfo.AggregateFunction.AVG, calls.get(2).getFunction()); + assertEquals(AggregateInfo.AggregateFunction.MIN, calls.get(3).getFunction()); + 
assertEquals(AggregateInfo.AggregateFunction.MAX, calls.get(4).getFunction()); } - // ==================== AggregateInfo Integration Tests ==================== - - @Nested - @DisplayName("AggregateInfo Integration Tests") - class AggregateInfoIntegrationTests { - - @Test - @DisplayName("Simple COUNT(*) aggregate info build") - void testSimpleCountStarAggregateInfo() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addCountStar("cnt") - .build(); - - assertTrue(aggInfo.isSimpleCountStar()); - assertFalse(aggInfo.hasGroupBy()); - assertEquals(1, aggInfo.getAggregateCalls().size()); - } - - @Test - @DisplayName("Aggregate info build with GROUP BY") - void testGroupByAggregateInfo() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addSum("amount", "total_amount") - .addAvg("amount", "avg_amount") - .groupBy("category") - .groupByFieldIndices(new int[]{2}) // category at index 2 - .build(); - - assertFalse(aggInfo.isSimpleCountStar()); - assertTrue(aggInfo.hasGroupBy()); - assertEquals(2, aggInfo.getAggregateCalls().size()); - assertEquals(Collections.singletonList("category"), aggInfo.getGroupByColumns()); - } - - @Test - @DisplayName("Multiple aggregates info build") - void testMultipleAggregatesInfo() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addCountStar("cnt") - .addSum("amount", "sum_amount") - .addAvg("amount", "avg_amount") - .addMin("amount", "min_amount") - .addMax("amount", "max_amount") - .build(); - - assertEquals(5, aggInfo.getAggregateCalls().size()); - - // Verify each aggregate function type - List calls = aggInfo.getAggregateCalls(); - assertEquals(AggregateInfo.AggregateFunction.COUNT, calls.get(0).getFunction()); - assertEquals(AggregateInfo.AggregateFunction.SUM, calls.get(1).getFunction()); - assertEquals(AggregateInfo.AggregateFunction.AVG, calls.get(2).getFunction()); - assertEquals(AggregateInfo.AggregateFunction.MIN, calls.get(3).getFunction()); - assertEquals(AggregateInfo.AggregateFunction.MAX, calls.get(4).getFunction()); - } - - @Test - @DisplayName("getRequiredColumns should return correct columns") - void testGetRequiredColumns() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addSum("amount", "sum_amount") - .addAvg("quantity", "avg_quantity") - .groupBy("category") - .build(); - - List required = aggInfo.getRequiredColumns(); - - assertTrue(required.contains("category")); - assertTrue(required.contains("amount")); - assertTrue(required.contains("quantity")); - } + @Test + @DisplayName("getRequiredColumns should return correct columns") + void testGetRequiredColumns() { + AggregateInfo aggInfo = + AggregateInfo.builder() + .addSum("amount", "sum_amount") + .addAvg("quantity", "avg_quantity") + .groupBy("category") + .build(); + + List required = aggInfo.getRequiredColumns(); + + assertTrue(required.contains("category")); + assertTrue(required.contains("amount")); + assertTrue(required.contains("quantity")); } + } + + // ==================== Combined Functionality Tests ==================== + + @Nested + @DisplayName("Combined Functionality Tests") + class CombinedFunctionalityTests { + + @Test + @DisplayName("Aggregate push-down with filter push-down combination") + void testAggregatePushDownWithFilter() { + LanceDynamicTableSource source = new LanceDynamicTableSource(options, physicalDataType); - // ==================== Combined Functionality Tests ==================== - - @Nested - @DisplayName("Combined Functionality Tests") - class CombinedFunctionalityTests { - - @Test - @DisplayName("Aggregate push-down with filter push-down 
combination") - void testAggregatePushDownWithFilter() { - LanceDynamicTableSource source = new LanceDynamicTableSource(options, physicalDataType); - - // Simulate adding filter conditions (through internal filters list) - // Note: Actual filter push-down is done through applyFilters method - - // Verify source can support both filter and aggregate push-down - assertNotNull(source.getOptions()); - } - - @Test - @DisplayName("Aggregate push-down with column pruning combination") - void testAggregatePushDownWithProjection() { - LanceDynamicTableSource source = new LanceDynamicTableSource(options, physicalDataType); - - // Apply column pruning - source.applyProjection(new int[][]{{0}, {3}, {4}}); // id, amount, quantity - - // Verify source still works correctly - assertNotNull(source.getOptions()); - } - - @Test - @DisplayName("Aggregate push-down with Limit combination") - void testAggregatePushDownWithLimit() { - LanceDynamicTableSource source = new LanceDynamicTableSource(options, physicalDataType); - - // Apply Limit - source.applyLimit(100); - - assertEquals(Long.valueOf(100), source.getLimit()); - } + // Simulate adding filter conditions (through internal filters list) + // Note: Actual filter push-down is done through applyFilters method + + // Verify source can support both filter and aggregate push-down + assertNotNull(source.getOptions()); } - // ==================== Edge Case Tests ==================== - - @Nested - @DisplayName("Edge Case Tests") - class EdgeCaseTests { - - @Test - @DisplayName("Multiple group by columns should be handled correctly") - void testMultipleGroupByColumns() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addCountStar("cnt") - .groupBy("category", "name") - .groupByFieldIndices(new int[]{2, 1}) - .build(); - - assertEquals(2, aggInfo.getGroupByColumns().size()); - assertArrayEquals(new int[]{2, 1}, aggInfo.getGroupByFieldIndices()); - } - - @Test - @DisplayName("Multiple aggregates on same column should be handled correctly") - void testMultipleAggregatesOnSameColumn() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addSum("amount", "sum_amount") - .addAvg("amount", "avg_amount") - .addMin("amount", "min_amount") - .addMax("amount", "max_amount") - .addCount("amount", "count_amount") - .build(); - - assertEquals(5, aggInfo.getAggregateCalls().size()); - - // Verify getRequiredColumns contains amount only once - List required = aggInfo.getRequiredColumns(); - long amountCount = required.stream().filter(c -> c.equals("amount")).count(); - assertEquals(1, amountCount); - } - - @Test - @DisplayName("Empty group by set should be handled correctly") - void testEmptyGroupBy() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addCountStar("cnt") - .build(); - - assertFalse(aggInfo.hasGroupBy()); - assertTrue(aggInfo.getGroupByColumns().isEmpty()); - assertEquals(0, aggInfo.getGroupByFieldIndices().length); - } + @Test + @DisplayName("Aggregate push-down with column pruning combination") + void testAggregatePushDownWithProjection() { + LanceDynamicTableSource source = new LanceDynamicTableSource(options, physicalDataType); + + // Apply column pruning + source.applyProjection(new int[][] {{0}, {3}, {4}}); // id, amount, quantity + + // Verify source still works correctly + assertNotNull(source.getOptions()); + } + + @Test + @DisplayName("Aggregate push-down with Limit combination") + void testAggregatePushDownWithLimit() { + LanceDynamicTableSource source = new LanceDynamicTableSource(options, physicalDataType); + + // Apply Limit + 
source.applyLimit(100); + + assertEquals(Long.valueOf(100), source.getLimit()); + } + } + + // ==================== Edge Case Tests ==================== + + @Nested + @DisplayName("Edge Case Tests") + class EdgeCaseTests { + + @Test + @DisplayName("Multiple group by columns should be handled correctly") + void testMultipleGroupByColumns() { + AggregateInfo aggInfo = + AggregateInfo.builder() + .addCountStar("cnt") + .groupBy("category", "name") + .groupByFieldIndices(new int[] {2, 1}) + .build(); + + assertEquals(2, aggInfo.getGroupByColumns().size()); + assertArrayEquals(new int[] {2, 1}, aggInfo.getGroupByFieldIndices()); + } + + @Test + @DisplayName("Multiple aggregates on same column should be handled correctly") + void testMultipleAggregatesOnSameColumn() { + AggregateInfo aggInfo = + AggregateInfo.builder() + .addSum("amount", "sum_amount") + .addAvg("amount", "avg_amount") + .addMin("amount", "min_amount") + .addMax("amount", "max_amount") + .addCount("amount", "count_amount") + .build(); + + assertEquals(5, aggInfo.getAggregateCalls().size()); + + // Verify getRequiredColumns contains amount only once + List required = aggInfo.getRequiredColumns(); + long amountCount = required.stream().filter(c -> c.equals("amount")).count(); + assertEquals(1, amountCount); + } + + @Test + @DisplayName("Empty group by set should be handled correctly") + void testEmptyGroupBy() { + AggregateInfo aggInfo = AggregateInfo.builder().addCountStar("cnt").build(); + + assertFalse(aggInfo.hasGroupBy()); + assertTrue(aggInfo.getGroupByColumns().isEmpty()); + assertEquals(0, aggInfo.getGroupByFieldIndices().length); + } + } + + // ==================== Aggregate Function Support Tests ==================== + + @Nested + @DisplayName("Aggregate Function Support Tests") + class AggregateFunctionSupportTests { + + @Test + @DisplayName("COUNT function should be supported") + void testCountSupport() { + AggregateInfo aggInfo = AggregateInfo.builder().addCountStar("cnt").build(); + + AggregateInfo.AggregateCall call = aggInfo.getAggregateCalls().get(0); + assertEquals(AggregateInfo.AggregateFunction.COUNT, call.getFunction()); + assertTrue(call.isCountStar()); + } + + @Test + @DisplayName("SUM function should be supported") + void testSumSupport() { + AggregateInfo aggInfo = AggregateInfo.builder().addSum("amount", "sum_amount").build(); + + AggregateInfo.AggregateCall call = aggInfo.getAggregateCalls().get(0); + assertEquals(AggregateInfo.AggregateFunction.SUM, call.getFunction()); + assertEquals("amount", call.getColumn()); + } + + @Test + @DisplayName("AVG function should be supported") + void testAvgSupport() { + AggregateInfo aggInfo = AggregateInfo.builder().addAvg("amount", "avg_amount").build(); + + AggregateInfo.AggregateCall call = aggInfo.getAggregateCalls().get(0); + assertEquals(AggregateInfo.AggregateFunction.AVG, call.getFunction()); + assertEquals("amount", call.getColumn()); + } + + @Test + @DisplayName("MIN function should be supported") + void testMinSupport() { + AggregateInfo aggInfo = AggregateInfo.builder().addMin("amount", "min_amount").build(); + + AggregateInfo.AggregateCall call = aggInfo.getAggregateCalls().get(0); + assertEquals(AggregateInfo.AggregateFunction.MIN, call.getFunction()); + assertEquals("amount", call.getColumn()); + } + + @Test + @DisplayName("MAX function should be supported") + void testMaxSupport() { + AggregateInfo aggInfo = AggregateInfo.builder().addMax("amount", "max_amount").build(); + + AggregateInfo.AggregateCall call = aggInfo.getAggregateCalls().get(0); + 
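// exactly one call was registered via addMax, so index 0 must be MAX over "amount"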
assertEquals(AggregateInfo.AggregateFunction.MAX, call.getFunction()); + assertEquals("amount", call.getColumn()); } - // ==================== Aggregate Function Support Tests ==================== - - @Nested - @DisplayName("Aggregate Function Support Tests") - class AggregateFunctionSupportTests { - - @Test - @DisplayName("COUNT function should be supported") - void testCountSupport() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addCountStar("cnt") - .build(); - - AggregateInfo.AggregateCall call = aggInfo.getAggregateCalls().get(0); - assertEquals(AggregateInfo.AggregateFunction.COUNT, call.getFunction()); - assertTrue(call.isCountStar()); - } - - @Test - @DisplayName("SUM function should be supported") - void testSumSupport() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addSum("amount", "sum_amount") - .build(); - - AggregateInfo.AggregateCall call = aggInfo.getAggregateCalls().get(0); - assertEquals(AggregateInfo.AggregateFunction.SUM, call.getFunction()); - assertEquals("amount", call.getColumn()); - } - - @Test - @DisplayName("AVG function should be supported") - void testAvgSupport() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addAvg("amount", "avg_amount") - .build(); - - AggregateInfo.AggregateCall call = aggInfo.getAggregateCalls().get(0); - assertEquals(AggregateInfo.AggregateFunction.AVG, call.getFunction()); - assertEquals("amount", call.getColumn()); - } - - @Test - @DisplayName("MIN function should be supported") - void testMinSupport() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addMin("amount", "min_amount") - .build(); - - AggregateInfo.AggregateCall call = aggInfo.getAggregateCalls().get(0); - assertEquals(AggregateInfo.AggregateFunction.MIN, call.getFunction()); - assertEquals("amount", call.getColumn()); - } - - @Test - @DisplayName("MAX function should be supported") - void testMaxSupport() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addMax("amount", "max_amount") - .build(); - - AggregateInfo.AggregateCall call = aggInfo.getAggregateCalls().get(0); - assertEquals(AggregateInfo.AggregateFunction.MAX, call.getFunction()); - assertEquals("amount", call.getColumn()); - } - - @Test - @DisplayName("COUNT DISTINCT function should be supported") - void testCountDistinctSupport() { - AggregateInfo aggInfo = AggregateInfo.builder() - .addAggregateCall(AggregateInfo.AggregateFunction.COUNT_DISTINCT, "category", "distinct_cnt") - .build(); - - AggregateInfo.AggregateCall call = aggInfo.getAggregateCalls().get(0); - assertEquals(AggregateInfo.AggregateFunction.COUNT_DISTINCT, call.getFunction()); - assertEquals("category", call.getColumn()); - } + @Test + @DisplayName("COUNT DISTINCT function should be supported") + void testCountDistinctSupport() { + AggregateInfo aggInfo = + AggregateInfo.builder() + .addAggregateCall( + AggregateInfo.AggregateFunction.COUNT_DISTINCT, "category", "distinct_cnt") + .build(); + + AggregateInfo.AggregateCall call = aggInfo.getAggregateCalls().get(0); + assertEquals(AggregateInfo.AggregateFunction.COUNT_DISTINCT, call.getFunction()); + assertEquals("category", call.getColumn()); } + } } diff --git a/src/test/java/org/apache/flink/connector/lance/table/LanceCatalogS3Test.java b/src/test/java/org/apache/flink/connector/lance/table/LanceCatalogS3Test.java index 4d0be5b..05351a9 100644 --- a/src/test/java/org/apache/flink/connector/lance/table/LanceCatalogS3Test.java +++ b/src/test/java/org/apache/flink/connector/lance/table/LanceCatalogS3Test.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache 
Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.flink.connector.lance.table; import org.apache.flink.table.api.EnvironmentSettings; @@ -47,616 +42,608 @@ /** * Lance Catalog S3 integration tests. - * + * *
  * <p>This test class is divided into two parts:
+ *
  * <ul>
- *   <li>Unit tests that don't require MinIO connection (always run)</li>
- *   <li>Integration tests that require MinIO connection (require external MinIO service configuration)</li>
+ *   <li>Unit tests that don't require MinIO connection (always run)
+ *   <li>Integration tests that require MinIO connection (require external MinIO service
+ *       configuration)
  * </ul>
- *
+ *
  * <p>To run tests that require MinIO, set the following environment variables:
+ *
  * <ul>
- *   <li>MINIO_ENDPOINT - MinIO service address, e.g., http://localhost:9000</li>
- *   <li>MINIO_ACCESS_KEY - MinIO access key (default: minioadmin)</li>
- *   <li>MINIO_SECRET_KEY - MinIO secret key (default: minioadmin)</li>
- *   <li>MINIO_BUCKET - Test bucket name (default: lance-test-bucket)</li>
+ *   <li>MINIO_ENDPOINT - MinIO service address, e.g., http://localhost:9000
+ *   <li>MINIO_ACCESS_KEY - MinIO access key (default: minioadmin)
+ *   <li>MINIO_SECRET_KEY - MinIO secret key (default: minioadmin)
+ *   <li>MINIO_BUCKET - Test bucket name (default: lance-test-bucket)
  * </ul>
- *
+ *
  * <p>Quick way to start MinIO (using Docker):
+ *
  * <pre>
  * docker run -p 9000:9000 -p 9001:9001 \
  *   -e "MINIO_ROOT_USER=minioadmin" \
  *   -e "MINIO_ROOT_PASSWORD=minioadmin" \
  *   minio/minio server /data --console-address ":9001"
  * </pre>
- *
+ *
  * <p>
Or use a locally installed MinIO service. */ class LanceCatalogS3Test { - private static final Logger LOG = LoggerFactory.getLogger(LanceCatalogS3Test.class); - - // MinIO configuration - read from environment variables or system properties - private static String minioEndpoint; - private static String minioAccessKey; - private static String minioSecretKey; - private static String testBucket; - private static boolean minioAvailable = false; - - /** - * Check if MinIO is available - */ - static boolean isMinioAvailable() { - return minioAvailable; - } - - @BeforeAll - static void initMinioConfig() { - // Read configuration from environment variables - minioEndpoint = getConfigValue("MINIO_ENDPOINT", "minio.endpoint", null); - minioAccessKey = getConfigValue("MINIO_ACCESS_KEY", "minio.access.key", "minioadmin"); - minioSecretKey = getConfigValue("MINIO_SECRET_KEY", "minio.secret.key", "minioadmin"); - testBucket = getConfigValue("MINIO_BUCKET", "minio.bucket", "lance-test-bucket"); - - if (minioEndpoint != null && !minioEndpoint.isEmpty()) { - LOG.info("MinIO configuration detected:"); - LOG.info(" Endpoint: {}", minioEndpoint); - LOG.info(" Bucket: {}", testBucket); - - // Try to connect to MinIO to verify availability - try { - minioAvailable = checkMinioConnection(); - if (minioAvailable) { - LOG.info("MinIO connection verification successful, integration tests will be enabled"); - } else { - LOG.warn("MinIO connection verification failed, integration tests will be skipped"); - } - } catch (Exception e) { - LOG.warn("MinIO connection check failed: {}, integration tests will be skipped", e.getMessage()); - minioAvailable = false; - } + private static final Logger LOG = LoggerFactory.getLogger(LanceCatalogS3Test.class); + + // MinIO configuration - read from environment variables or system properties + private static String minioEndpoint; + private static String minioAccessKey; + private static String minioSecretKey; + private static String testBucket; + private static boolean minioAvailable = false; + + /** Check if MinIO is available */ + static boolean isMinioAvailable() { + return minioAvailable; + } + + @BeforeAll + static void initMinioConfig() { + // Read configuration from environment variables + minioEndpoint = getConfigValue("MINIO_ENDPOINT", "minio.endpoint", null); + minioAccessKey = getConfigValue("MINIO_ACCESS_KEY", "minio.access.key", "minioadmin"); + minioSecretKey = getConfigValue("MINIO_SECRET_KEY", "minio.secret.key", "minioadmin"); + testBucket = getConfigValue("MINIO_BUCKET", "minio.bucket", "lance-test-bucket"); + + if (minioEndpoint != null && !minioEndpoint.isEmpty()) { + LOG.info("MinIO configuration detected:"); + LOG.info(" Endpoint: {}", minioEndpoint); + LOG.info(" Bucket: {}", testBucket); + + // Try to connect to MinIO to verify availability + try { + minioAvailable = checkMinioConnection(); + if (minioAvailable) { + LOG.info("MinIO connection verification successful, integration tests will be enabled"); } else { - LOG.info("No MinIO configuration detected (MINIO_ENDPOINT environment variable not set), integration tests will be skipped"); - LOG.info("To enable MinIO integration tests, set the following environment variables:"); - LOG.info(" export MINIO_ENDPOINT=http://localhost:9000"); - LOG.info(" export MINIO_ACCESS_KEY=minioadmin"); - LOG.info(" export MINIO_SECRET_KEY=minioadmin"); - LOG.info(" export MINIO_BUCKET=lance-test-bucket"); - } + LOG.warn("MinIO connection verification failed, integration tests will be skipped"); + } + } catch (Exception e) { + 
LOG.warn( + "MinIO connection check failed: {}, integration tests will be skipped", e.getMessage()); + minioAvailable = false; + } + } else { + LOG.info( + "No MinIO configuration detected (MINIO_ENDPOINT environment variable not set), integration tests will be skipped"); + LOG.info("To enable MinIO integration tests, set the following environment variables:"); + LOG.info(" export MINIO_ENDPOINT=http://localhost:9000"); + LOG.info(" export MINIO_ACCESS_KEY=minioadmin"); + LOG.info(" export MINIO_SECRET_KEY=minioadmin"); + LOG.info(" export MINIO_BUCKET=lance-test-bucket"); } + } - /** - * Get configuration value from environment variable or system property - */ - private static String getConfigValue(String envKey, String propKey, String defaultValue) { - String value = System.getenv(envKey); - if (value == null || value.isEmpty()) { - value = System.getProperty(propKey, defaultValue); - } - return value; - } - - /** - * Check if MinIO connection is available - */ - private static boolean checkMinioConnection() { - try { - // Try to create a simple HTTP connection to check if MinIO service is available - java.net.URL url = new java.net.URL(minioEndpoint + "/minio/health/live"); - java.net.HttpURLConnection connection = (java.net.HttpURLConnection) url.openConnection(); - connection.setRequestMethod("GET"); - connection.setConnectTimeout(5000); - connection.setReadTimeout(5000); - int responseCode = connection.getResponseCode(); - connection.disconnect(); - return responseCode == 200; - } catch (Exception e) { - LOG.debug("MinIO health check failed: {}", e.getMessage()); - return false; - } + /** Get configuration value from environment variable or system property */ + private static String getConfigValue(String envKey, String propKey, String defaultValue) { + String value = System.getenv(envKey); + if (value == null || value.isEmpty()) { + value = System.getProperty(propKey, defaultValue); + } + return value; + } + + /** Check if MinIO connection is available */ + private static boolean checkMinioConnection() { + try { + // Try to create a simple HTTP connection to check if MinIO service is available + java.net.URL url = new java.net.URL(minioEndpoint + "/minio/health/live"); + java.net.HttpURLConnection connection = (java.net.HttpURLConnection) url.openConnection(); + connection.setRequestMethod("GET"); + connection.setConnectTimeout(5000); + connection.setReadTimeout(5000); + int responseCode = connection.getResponseCode(); + connection.disconnect(); + return responseCode == 200; + } catch (Exception e) { + LOG.debug("MinIO health check failed: {}", e.getMessage()); + return false; } + } - // ==================== Unit Tests That Don't Require MinIO (Always Run) ==================== + // ==================== Unit Tests That Don't Require MinIO (Always Run) ==================== - /** - * Unit tests that don't require MinIO connection - */ - @Nested - @DisplayName("Unit Tests - No MinIO Required") - class UnitTests { + /** Unit tests that don't require MinIO connection */ + @Nested + @DisplayName("Unit Tests - No MinIO Required") + class UnitTests { - // ==================== Remote Path Detection Tests ==================== + // ==================== Remote Path Detection Tests ==================== - @Test - @DisplayName("Test remote path detection - S3 protocol") - void testRemotePathDetectionS3() { - LanceCatalog catalog = new LanceCatalog("test", "default", "s3://bucket/path"); - assertThat(catalog.isRemoteStorage()).isTrue(); - } - - @Test - @DisplayName("Test remote path detection - S3A 
protocol") - void testRemotePathDetectionS3A() { - LanceCatalog catalog = new LanceCatalog("test", "default", "s3a://bucket/path"); - assertThat(catalog.isRemoteStorage()).isTrue(); - } + @Test + @DisplayName("Test remote path detection - S3 protocol") + void testRemotePathDetectionS3() { + LanceCatalog catalog = new LanceCatalog("test", "default", "s3://bucket/path"); + assertThat(catalog.isRemoteStorage()).isTrue(); + } - @Test - @DisplayName("Test remote path detection - GCS protocol") - void testRemotePathDetectionGCS() { - LanceCatalog catalog = new LanceCatalog("test", "default", "gs://bucket/path"); - assertThat(catalog.isRemoteStorage()).isTrue(); - } + @Test + @DisplayName("Test remote path detection - S3A protocol") + void testRemotePathDetectionS3A() { + LanceCatalog catalog = new LanceCatalog("test", "default", "s3a://bucket/path"); + assertThat(catalog.isRemoteStorage()).isTrue(); + } - @Test - @DisplayName("Test remote path detection - Azure protocol") - void testRemotePathDetectionAzure() { - LanceCatalog catalog = new LanceCatalog("test", "default", "az://container/path"); - assertThat(catalog.isRemoteStorage()).isTrue(); - } + @Test + @DisplayName("Test remote path detection - GCS protocol") + void testRemotePathDetectionGCS() { + LanceCatalog catalog = new LanceCatalog("test", "default", "gs://bucket/path"); + assertThat(catalog.isRemoteStorage()).isTrue(); + } - @Test - @DisplayName("Test local path detection") - void testLocalPathDetection() { - LanceCatalog catalog = new LanceCatalog("test", "default", "/tmp/local/path"); - assertThat(catalog.isRemoteStorage()).isFalse(); - } + @Test + @DisplayName("Test remote path detection - Azure protocol") + void testRemotePathDetectionAzure() { + LanceCatalog catalog = new LanceCatalog("test", "default", "az://container/path"); + assertThat(catalog.isRemoteStorage()).isTrue(); + } - // ==================== Factory Tests ==================== - - @Test - @DisplayName("Test LanceCatalogFactory S3 configuration options") - void testCatalogFactoryS3Options() { - LanceCatalogFactory factory = new LanceCatalogFactory(); - - Set optionalOptionKeys = new HashSet<>(); - factory.optionalOptions().forEach(opt -> optionalOptionKeys.add(opt.key())); - - // Verify S3 related options exist - assertThat(optionalOptionKeys).contains( - "s3-access-key", - "s3-secret-key", - "s3-region", - "s3-endpoint", - "s3-virtual-hosted-style", - "s3-allow-http" - ); - } + @Test + @DisplayName("Test local path detection") + void testLocalPathDetection() { + LanceCatalog catalog = new LanceCatalog("test", "default", "/tmp/local/path"); + assertThat(catalog.isRemoteStorage()).isFalse(); + } - @Test - @DisplayName("Test S3 configuration options default values") - void testS3ConfigOptionsDefaults() { - assertThat(LanceCatalogFactory.S3_VIRTUAL_HOSTED_STYLE.defaultValue()).isTrue(); - assertThat(LanceCatalogFactory.S3_ALLOW_HTTP.defaultValue()).isFalse(); - } + // ==================== Factory Tests ==================== + + @Test + @DisplayName("Test LanceCatalogFactory S3 configuration options") + void testCatalogFactoryS3Options() { + LanceCatalogFactory factory = new LanceCatalogFactory(); + + Set optionalOptionKeys = new HashSet<>(); + factory.optionalOptions().forEach(opt -> optionalOptionKeys.add(opt.key())); + + // Verify S3 related options exist + assertThat(optionalOptionKeys) + .contains( + "s3-access-key", + "s3-secret-key", + "s3-region", + "s3-endpoint", + "s3-virtual-hosted-style", + "s3-allow-http"); + } - @Test - @DisplayName("Test S3 configuration 
options descriptions") - void testS3ConfigOptionsDescriptions() { - // Verify configuration options exist and have descriptions - assertThat(LanceCatalogFactory.S3_ACCESS_KEY.key()).isEqualTo("s3-access-key"); - assertThat(LanceCatalogFactory.S3_SECRET_KEY.key()).isEqualTo("s3-secret-key"); - assertThat(LanceCatalogFactory.S3_REGION.key()).isEqualTo("s3-region"); - assertThat(LanceCatalogFactory.S3_ENDPOINT.key()).isEqualTo("s3-endpoint"); - - // Verify descriptions are not null - assertThat(LanceCatalogFactory.S3_ACCESS_KEY.description()).isNotNull(); - assertThat(LanceCatalogFactory.S3_SECRET_KEY.description()).isNotNull(); - assertThat(LanceCatalogFactory.S3_REGION.description()).isNotNull(); - assertThat(LanceCatalogFactory.S3_ENDPOINT.description()).isNotNull(); - } + @Test + @DisplayName("Test S3 configuration options default values") + void testS3ConfigOptionsDefaults() { + assertThat(LanceCatalogFactory.S3_VIRTUAL_HOSTED_STYLE.defaultValue()).isTrue(); + assertThat(LanceCatalogFactory.S3_ALLOW_HTTP.defaultValue()).isFalse(); + } - // ==================== Path Normalization Tests ==================== + @Test + @DisplayName("Test S3 configuration options descriptions") + void testS3ConfigOptionsDescriptions() { + // Verify configuration options exist and have descriptions + assertThat(LanceCatalogFactory.S3_ACCESS_KEY.key()).isEqualTo("s3-access-key"); + assertThat(LanceCatalogFactory.S3_SECRET_KEY.key()).isEqualTo("s3-secret-key"); + assertThat(LanceCatalogFactory.S3_REGION.key()).isEqualTo("s3-region"); + assertThat(LanceCatalogFactory.S3_ENDPOINT.key()).isEqualTo("s3-endpoint"); + + // Verify descriptions are not null + assertThat(LanceCatalogFactory.S3_ACCESS_KEY.description()).isNotNull(); + assertThat(LanceCatalogFactory.S3_SECRET_KEY.description()).isNotNull(); + assertThat(LanceCatalogFactory.S3_REGION.description()).isNotNull(); + assertThat(LanceCatalogFactory.S3_ENDPOINT.description()).isNotNull(); + } - @Test - @DisplayName("Test warehouse path normalization - remove trailing slashes") - void testWarehousePathNormalization() { - LanceCatalog catalog1 = new LanceCatalog("test", "default", "s3://bucket/path/"); - assertThat(catalog1.getWarehouse()).isEqualTo("s3://bucket/path"); + // ==================== Path Normalization Tests ==================== - LanceCatalog catalog2 = new LanceCatalog("test", "default", "s3://bucket/path///"); - assertThat(catalog2.getWarehouse()).isEqualTo("s3://bucket/path"); - } + @Test + @DisplayName("Test warehouse path normalization - remove trailing slashes") + void testWarehousePathNormalization() { + LanceCatalog catalog1 = new LanceCatalog("test", "default", "s3://bucket/path/"); + assertThat(catalog1.getWarehouse()).isEqualTo("s3://bucket/path"); - @Test - @DisplayName("Test warehouse path normalization - preserve root path") - void testWarehousePathNormalizationRoot() { - LanceCatalog catalog = new LanceCatalog("test", "default", "s3://bucket"); - assertThat(catalog.getWarehouse()).isEqualTo("s3://bucket"); - } + LanceCatalog catalog2 = new LanceCatalog("test", "default", "s3://bucket/path///"); + assertThat(catalog2.getWarehouse()).isEqualTo("s3://bucket/path"); + } - // ==================== Edge Case Tests ==================== + @Test + @DisplayName("Test warehouse path normalization - preserve root path") + void testWarehousePathNormalizationRoot() { + LanceCatalog catalog = new LanceCatalog("test", "default", "s3://bucket"); + assertThat(catalog.getWarehouse()).isEqualTo("s3://bucket"); + } - @Test - @DisplayName("Test S3 path with 
empty storage options") - void testS3PathWithEmptyOptions() { - LanceCatalog catalog = new LanceCatalog( - "test", "default", - "s3://bucket/path", - Collections.emptyMap()); + // ==================== Edge Case Tests ==================== - assertThat(catalog.isRemoteStorage()).isTrue(); - assertThat(catalog.getStorageOptions()).isEmpty(); - } + @Test + @DisplayName("Test S3 path with empty storage options") + void testS3PathWithEmptyOptions() { + LanceCatalog catalog = + new LanceCatalog("test", "default", "s3://bucket/path", Collections.emptyMap()); - @Test - @DisplayName("Test null storage options") - void testNullStorageOptions() { - LanceCatalog catalog = new LanceCatalog( - "test", "default", - "s3://bucket/path", - null); + assertThat(catalog.isRemoteStorage()).isTrue(); + assertThat(catalog.getStorageOptions()).isEmpty(); + } - assertThat(catalog.isRemoteStorage()).isTrue(); - assertThat(catalog.getStorageOptions()).isEmpty(); - } + @Test + @DisplayName("Test null storage options") + void testNullStorageOptions() { + LanceCatalog catalog = new LanceCatalog("test", "default", "s3://bucket/path", null); - @Test - @DisplayName("Test storage options immutability") - void testStorageOptionsImmutability() { - Map originalOptions = new HashMap<>(); - originalOptions.put("key", "value"); + assertThat(catalog.isRemoteStorage()).isTrue(); + assertThat(catalog.getStorageOptions()).isEmpty(); + } - LanceCatalog catalog = new LanceCatalog( - "test", "default", - "s3://bucket/path", - originalOptions); + @Test + @DisplayName("Test storage options immutability") + void testStorageOptionsImmutability() { + Map originalOptions = new HashMap<>(); + originalOptions.put("key", "value"); - // Modifying original map should not affect catalog internal options - originalOptions.put("new_key", "new_value"); + LanceCatalog catalog = + new LanceCatalog("test", "default", "s3://bucket/path", originalOptions); - assertThat(catalog.getStorageOptions()).doesNotContainKey("new_key"); - } + // Modifying original map should not affect catalog internal options + originalOptions.put("new_key", "new_value"); - @Test - @DisplayName("Test getStorageOptions returns unmodifiable Map") - void testGetStorageOptionsReturnsUnmodifiable() { - Map storageOptions = new HashMap<>(); - storageOptions.put("key", "value"); + assertThat(catalog.getStorageOptions()).doesNotContainKey("new_key"); + } - LanceCatalog catalog = new LanceCatalog( - "test", "default", - "s3://bucket/path", - storageOptions); + @Test + @DisplayName("Test getStorageOptions returns unmodifiable Map") + void testGetStorageOptionsReturnsUnmodifiable() { + Map storageOptions = new HashMap<>(); + storageOptions.put("key", "value"); - Map returnedOptions = catalog.getStorageOptions(); + LanceCatalog catalog = + new LanceCatalog("test", "default", "s3://bucket/path", storageOptions); - // Attempting to modify returned map should throw exception - assertThatThrownBy(() -> returnedOptions.put("new_key", "new_value")) - .isInstanceOf(UnsupportedOperationException.class); - } + Map returnedOptions = catalog.getStorageOptions(); - @Test - @DisplayName("Test S3 Catalog basic properties (no connection required)") - void testS3CatalogBasicProperties() { - Map storageOptions = new HashMap<>(); - storageOptions.put("aws_access_key_id", "test_key"); - storageOptions.put("aws_secret_access_key", "test_secret"); - storageOptions.put("aws_region", "us-east-1"); - - LanceCatalog catalog = new LanceCatalog( - "test_catalog", "default", - "s3://test-bucket/warehouse", - 
storageOptions); - - assertThat(catalog.getName()).isEqualTo("test_catalog"); - assertThat(catalog.getDefaultDatabase()).isEqualTo("default"); - assertThat(catalog.getWarehouse()).isEqualTo("s3://test-bucket/warehouse"); - assertThat(catalog.isRemoteStorage()).isTrue(); - assertThat(catalog.getStorageOptions()).containsEntry("aws_access_key_id", "test_key"); - } + // Attempting to modify returned map should throw exception + assertThatThrownBy(() -> returnedOptions.put("new_key", "new_value")) + .isInstanceOf(UnsupportedOperationException.class); } - // ==================== Integration Tests That Require MinIO ==================== - - /** - * Integration tests that require MinIO connection. - * Only run when MINIO_ENDPOINT environment variable is set and MinIO service is available. - */ - @Nested - @DisplayName("Integration Tests - MinIO Required") - @EnabledIf("org.apache.flink.connector.lance.table.LanceCatalogS3Test#isMinioAvailable") - class MinioIntegrationTests { - - private LanceCatalog s3Catalog; - private String warehousePath; - private String testId; - - @BeforeEach - void setUp() throws Exception { - // Generate unique path for each test to avoid interference between tests - testId = UUID.randomUUID().toString().substring(0, 8); - warehousePath = String.format("s3://%s/lance-warehouse-%s", testBucket, testId); - - // Create Catalog with S3 configuration - Map storageOptions = new HashMap<>(); - storageOptions.put("aws_access_key_id", minioAccessKey); - storageOptions.put("aws_secret_access_key", minioSecretKey); - storageOptions.put("aws_region", "us-east-1"); - storageOptions.put("aws_endpoint", minioEndpoint); - storageOptions.put("aws_virtual_hosted_style_request", "false"); - storageOptions.put("allow_http", "true"); - - s3Catalog = new LanceCatalog("lance_s3_catalog", "default", warehousePath, storageOptions); - s3Catalog.open(); - - LOG.info("Test Catalog created, warehouse: {}", warehousePath); - } + @Test + @DisplayName("Test S3 Catalog basic properties (no connection required)") + void testS3CatalogBasicProperties() { + Map storageOptions = new HashMap<>(); + storageOptions.put("aws_access_key_id", "test_key"); + storageOptions.put("aws_secret_access_key", "test_secret"); + storageOptions.put("aws_region", "us-east-1"); + + LanceCatalog catalog = + new LanceCatalog("test_catalog", "default", "s3://test-bucket/warehouse", storageOptions); + + assertThat(catalog.getName()).isEqualTo("test_catalog"); + assertThat(catalog.getDefaultDatabase()).isEqualTo("default"); + assertThat(catalog.getWarehouse()).isEqualTo("s3://test-bucket/warehouse"); + assertThat(catalog.isRemoteStorage()).isTrue(); + assertThat(catalog.getStorageOptions()).containsEntry("aws_access_key_id", "test_key"); + } + } + + // ==================== Integration Tests That Require MinIO ==================== + + /** + * Integration tests that require MinIO connection. Only run when MINIO_ENDPOINT environment + * variable is set and MinIO service is available. 
+ */ + @Nested + @DisplayName("Integration Tests - MinIO Required") + @EnabledIf("org.apache.flink.connector.lance.table.LanceCatalogS3Test#isMinioAvailable") + class MinioIntegrationTests { + + private LanceCatalog s3Catalog; + private String warehousePath; + private String testId; + + @BeforeEach + void setUp() throws Exception { + // Generate unique path for each test to avoid interference between tests + testId = UUID.randomUUID().toString().substring(0, 8); + warehousePath = String.format("s3://%s/lance-warehouse-%s", testBucket, testId); + + // Create Catalog with S3 configuration + Map storageOptions = new HashMap<>(); + storageOptions.put("aws_access_key_id", minioAccessKey); + storageOptions.put("aws_secret_access_key", minioSecretKey); + storageOptions.put("aws_region", "us-east-1"); + storageOptions.put("aws_endpoint", minioEndpoint); + storageOptions.put("aws_virtual_hosted_style_request", "false"); + storageOptions.put("allow_http", "true"); + + s3Catalog = new LanceCatalog("lance_s3_catalog", "default", warehousePath, storageOptions); + s3Catalog.open(); + + LOG.info("Test Catalog created, warehouse: {}", warehousePath); + } - @AfterEach - void tearDown() throws Exception { - if (s3Catalog != null) { - s3Catalog.close(); - } - } + @AfterEach + void tearDown() throws Exception { + if (s3Catalog != null) { + s3Catalog.close(); + } + } - // ==================== Basic Properties Tests ==================== + // ==================== Basic Properties Tests ==================== - @Test - @DisplayName("Test S3 Catalog basic properties") - void testS3CatalogProperties() { - assertThat(s3Catalog.getName()).isEqualTo("lance_s3_catalog"); - assertThat(s3Catalog.getDefaultDatabase()).isEqualTo("default"); - assertThat(s3Catalog.getWarehouse()).isEqualTo(warehousePath); - assertThat(s3Catalog.isRemoteStorage()).isTrue(); - } + @Test + @DisplayName("Test S3 Catalog basic properties") + void testS3CatalogProperties() { + assertThat(s3Catalog.getName()).isEqualTo("lance_s3_catalog"); + assertThat(s3Catalog.getDefaultDatabase()).isEqualTo("default"); + assertThat(s3Catalog.getWarehouse()).isEqualTo(warehousePath); + assertThat(s3Catalog.isRemoteStorage()).isTrue(); + } - @Test - @DisplayName("Test S3 storage options configuration") - void testS3StorageOptions() { - Map options = s3Catalog.getStorageOptions(); + @Test + @DisplayName("Test S3 storage options configuration") + void testS3StorageOptions() { + Map options = s3Catalog.getStorageOptions(); - assertThat(options).containsEntry("aws_access_key_id", minioAccessKey); - assertThat(options).containsEntry("aws_secret_access_key", minioSecretKey); - assertThat(options).containsEntry("aws_region", "us-east-1"); - assertThat(options).containsEntry("aws_endpoint", minioEndpoint); - assertThat(options).containsEntry("allow_http", "true"); - } + assertThat(options).containsEntry("aws_access_key_id", minioAccessKey); + assertThat(options).containsEntry("aws_secret_access_key", minioSecretKey); + assertThat(options).containsEntry("aws_region", "us-east-1"); + assertThat(options).containsEntry("aws_endpoint", minioEndpoint); + assertThat(options).containsEntry("allow_http", "true"); + } - // ==================== Database Operation Tests ==================== + // ==================== Database Operation Tests ==================== - @Test - @DisplayName("Test S3 Catalog default database exists") - void testDefaultDatabaseExists() throws Exception { - assertThat(s3Catalog.databaseExists("default")).isTrue(); - } + @Test + @DisplayName("Test S3 Catalog 
default database exists") + void testDefaultDatabaseExists() throws Exception { + assertThat(s3Catalog.databaseExists("default")).isTrue(); + } - @Test - @DisplayName("Test S3 Catalog list databases") - void testListDatabases() throws Exception { - List databases = s3Catalog.listDatabases(); - assertThat(databases).contains("default"); - } + @Test + @DisplayName("Test S3 Catalog list databases") + void testListDatabases() throws Exception { + List databases = s3Catalog.listDatabases(); + assertThat(databases).contains("default"); + } - @Test - @DisplayName("Test S3 Catalog create database") - void testCreateDatabase() throws Exception { - String dbName = "test_s3_db_" + testId; + @Test + @DisplayName("Test S3 Catalog create database") + void testCreateDatabase() throws Exception { + String dbName = "test_s3_db_" + testId; - // Create database - s3Catalog.createDatabase(dbName, null, false); + // Create database + s3Catalog.createDatabase(dbName, null, false); - // Verify database exists - assertThat(s3Catalog.databaseExists(dbName)).isTrue(); + // Verify database exists + assertThat(s3Catalog.databaseExists(dbName)).isTrue(); - // Verify database in list - List databases = s3Catalog.listDatabases(); - assertThat(databases).contains(dbName); - } + // Verify database in list + List databases = s3Catalog.listDatabases(); + assertThat(databases).contains(dbName); + } - @Test - @DisplayName("Test S3 Catalog create existing database (ignoreIfExists=false)") - void testCreateExistingDatabaseWithoutIgnore() throws Exception { - String dbName = "existing_db_" + testId; - s3Catalog.createDatabase(dbName, null, false); + @Test + @DisplayName("Test S3 Catalog create existing database (ignoreIfExists=false)") + void testCreateExistingDatabaseWithoutIgnore() throws Exception { + String dbName = "existing_db_" + testId; + s3Catalog.createDatabase(dbName, null, false); - // Creating again should throw exception - assertThatThrownBy(() -> s3Catalog.createDatabase(dbName, null, false)) - .isInstanceOf(DatabaseAlreadyExistException.class); - } + // Creating again should throw exception + assertThatThrownBy(() -> s3Catalog.createDatabase(dbName, null, false)) + .isInstanceOf(DatabaseAlreadyExistException.class); + } - @Test - @DisplayName("Test S3 Catalog create existing database (ignoreIfExists=true)") - void testCreateExistingDatabaseWithIgnore() throws Exception { - String dbName = "existing_db_2_" + testId; - s3Catalog.createDatabase(dbName, null, false); + @Test + @DisplayName("Test S3 Catalog create existing database (ignoreIfExists=true)") + void testCreateExistingDatabaseWithIgnore() throws Exception { + String dbName = "existing_db_2_" + testId; + s3Catalog.createDatabase(dbName, null, false); - // Creating again should not throw exception - s3Catalog.createDatabase(dbName, null, true); + // Creating again should not throw exception + s3Catalog.createDatabase(dbName, null, true); - assertThat(s3Catalog.databaseExists(dbName)).isTrue(); - } + assertThat(s3Catalog.databaseExists(dbName)).isTrue(); + } - @Test - @DisplayName("Test S3 Catalog get database") - void testGetDatabase() throws Exception { - String dbName = "get_db_test_" + testId; - s3Catalog.createDatabase(dbName, null, false); + @Test + @DisplayName("Test S3 Catalog get database") + void testGetDatabase() throws Exception { + String dbName = "get_db_test_" + testId; + s3Catalog.createDatabase(dbName, null, false); - CatalogDatabase database = s3Catalog.getDatabase(dbName); - assertThat(database).isNotNull(); - 
assertThat(database.getComment()).contains("Lance Database"); - } + CatalogDatabase database = s3Catalog.getDatabase(dbName); + assertThat(database).isNotNull(); + assertThat(database.getComment()).contains("Lance Database"); + } - @Test - @DisplayName("Test S3 Catalog get non-existing database") - void testGetNonExistingDatabase() { - assertThatThrownBy(() -> s3Catalog.getDatabase("non_existing_db_" + testId)) - .isInstanceOf(DatabaseNotExistException.class); - } + @Test + @DisplayName("Test S3 Catalog get non-existing database") + void testGetNonExistingDatabase() { + assertThatThrownBy(() -> s3Catalog.getDatabase("non_existing_db_" + testId)) + .isInstanceOf(DatabaseNotExistException.class); + } - @Test - @DisplayName("Test S3 Catalog drop database") - void testDropDatabase() throws Exception { - String dbName = "drop_db_test_" + testId; - s3Catalog.createDatabase(dbName, null, false); - assertThat(s3Catalog.databaseExists(dbName)).isTrue(); + @Test + @DisplayName("Test S3 Catalog drop database") + void testDropDatabase() throws Exception { + String dbName = "drop_db_test_" + testId; + s3Catalog.createDatabase(dbName, null, false); + assertThat(s3Catalog.databaseExists(dbName)).isTrue(); - // Drop database - s3Catalog.dropDatabase(dbName, false, false); + // Drop database + s3Catalog.dropDatabase(dbName, false, false); - // Verify database not in list - List databases = s3Catalog.listDatabases(); - assertThat(databases).doesNotContain(dbName); - } + // Verify database not in list + List databases = s3Catalog.listDatabases(); + assertThat(databases).doesNotContain(dbName); + } - @Test - @DisplayName("Test S3 Catalog drop non-existing database (ignoreIfNotExists=false)") - void testDropNonExistingDatabaseWithoutIgnore() { - assertThatThrownBy(() -> s3Catalog.dropDatabase("non_existing_drop_db_" + testId, false, false)) - .isInstanceOf(DatabaseNotExistException.class); - } + @Test + @DisplayName("Test S3 Catalog drop non-existing database (ignoreIfNotExists=false)") + void testDropNonExistingDatabaseWithoutIgnore() { + assertThatThrownBy( + () -> s3Catalog.dropDatabase("non_existing_drop_db_" + testId, false, false)) + .isInstanceOf(DatabaseNotExistException.class); + } - @Test - @DisplayName("Test S3 Catalog drop non-existing database (ignoreIfNotExists=true)") - void testDropNonExistingDatabaseWithIgnore() throws Exception { - // Should not throw exception - s3Catalog.dropDatabase("non_existing_drop_db_2_" + testId, true, false); - } + @Test + @DisplayName("Test S3 Catalog drop non-existing database (ignoreIfNotExists=true)") + void testDropNonExistingDatabaseWithIgnore() throws Exception { + // Should not throw exception + s3Catalog.dropDatabase("non_existing_drop_db_2_" + testId, true, false); + } - // ==================== Table Operation Tests ==================== + // ==================== Table Operation Tests ==================== - @Test - @DisplayName("Test S3 Catalog list tables (empty database)") - void testListTablesEmpty() throws Exception { - String dbName = "empty_tables_db_" + testId; - s3Catalog.createDatabase(dbName, null, false); + @Test + @DisplayName("Test S3 Catalog list tables (empty database)") + void testListTablesEmpty() throws Exception { + String dbName = "empty_tables_db_" + testId; + s3Catalog.createDatabase(dbName, null, false); - List tables = s3Catalog.listTables(dbName); - assertThat(tables).isEmpty(); - } + List tables = s3Catalog.listTables(dbName); + assertThat(tables).isEmpty(); + } - @Test - @DisplayName("Test S3 Catalog table not exists") - void 
testTableNotExists() throws Exception { - String dbName = "table_check_db_" + testId; - s3Catalog.createDatabase(dbName, null, false); + @Test + @DisplayName("Test S3 Catalog table not exists") + void testTableNotExists() throws Exception { + String dbName = "table_check_db_" + testId; + s3Catalog.createDatabase(dbName, null, false); - assertThat(s3Catalog.tableExists(new org.apache.flink.table.catalog.ObjectPath(dbName, "non_existing_table"))) - .isFalse(); - } + assertThat( + s3Catalog.tableExists( + new org.apache.flink.table.catalog.ObjectPath(dbName, "non_existing_table"))) + .isFalse(); + } - // ==================== SQL DDL Create Catalog Tests ==================== - - @Test - @DisplayName("Test creating S3 Catalog via SQL DDL") - void testCreateS3CatalogViaSql() throws Exception { - EnvironmentSettings settings = EnvironmentSettings.newInstance() - .inBatchMode() - .build(); - TableEnvironment tableEnv = TableEnvironment.create(settings); - - String catalogName = "lance_s3_sql_" + testId; - - // Create S3 Catalog using SQL - String createCatalogSql = String.format( - "CREATE CATALOG %s WITH (" + - "'type' = 'lance', " + - "'warehouse' = '%s', " + - "'default-database' = 'default', " + - "'s3-access-key' = '%s', " + - "'s3-secret-key' = '%s', " + - "'s3-region' = 'us-east-1', " + - "'s3-endpoint' = '%s', " + - "'s3-allow-http' = 'true', " + - "'s3-virtual-hosted-style' = 'false'" + - ")", - catalogName, warehousePath, minioAccessKey, minioSecretKey, minioEndpoint); - - tableEnv.executeSql(createCatalogSql); - - // Verify Catalog was created - String[] catalogs = tableEnv.listCatalogs(); - assertThat(catalogs).contains(catalogName); - - // Use Catalog - tableEnv.useCatalog(catalogName); - assertThat(tableEnv.getCurrentCatalog()).isEqualTo(catalogName); - - // Verify default database - assertThat(tableEnv.getCurrentDatabase()).isEqualTo("default"); - } + // ==================== SQL DDL Create Catalog Tests ==================== + + @Test + @DisplayName("Test creating S3 Catalog via SQL DDL") + void testCreateS3CatalogViaSql() throws Exception { + EnvironmentSettings settings = EnvironmentSettings.newInstance().inBatchMode().build(); + TableEnvironment tableEnv = TableEnvironment.create(settings); + + String catalogName = "lance_s3_sql_" + testId; + + // Create S3 Catalog using SQL + String createCatalogSql = + String.format( + "CREATE CATALOG %s WITH (" + + "'type' = 'lance', " + + "'warehouse' = '%s', " + + "'default-database' = 'default', " + + "'s3-access-key' = '%s', " + + "'s3-secret-key' = '%s', " + + "'s3-region' = 'us-east-1', " + + "'s3-endpoint' = '%s', " + + "'s3-allow-http' = 'true', " + + "'s3-virtual-hosted-style' = 'false'" + + ")", + catalogName, warehousePath, minioAccessKey, minioSecretKey, minioEndpoint); + + tableEnv.executeSql(createCatalogSql); + + // Verify Catalog was created + String[] catalogs = tableEnv.listCatalogs(); + assertThat(catalogs).contains(catalogName); + + // Use Catalog + tableEnv.useCatalog(catalogName); + assertThat(tableEnv.getCurrentCatalog()).isEqualTo(catalogName); + + // Verify default database + assertThat(tableEnv.getCurrentDatabase()).isEqualTo("default"); + } - @Test - @DisplayName("Test creating database in S3 Catalog via SQL DDL") - void testCreateDatabaseViaSql() throws Exception { - EnvironmentSettings settings = EnvironmentSettings.newInstance() - .inBatchMode() - .build(); - TableEnvironment tableEnv = TableEnvironment.create(settings); - - String catalogName = "lance_s3_db_sql_" + testId; - - // Create S3 Catalog - String 
createCatalogSql = String.format( - "CREATE CATALOG %s WITH (" + - "'type' = 'lance', " + - "'warehouse' = '%s', " + - "'s3-access-key' = '%s', " + - "'s3-secret-key' = '%s', " + - "'s3-region' = 'us-east-1', " + - "'s3-endpoint' = '%s', " + - "'s3-allow-http' = 'true', " + - "'s3-virtual-hosted-style' = 'false'" + - ")", - catalogName, warehousePath, minioAccessKey, minioSecretKey, minioEndpoint); - - tableEnv.executeSql(createCatalogSql); - tableEnv.useCatalog(catalogName); - - // Create database - String dbName = "test_database_" + testId; - tableEnv.executeSql("CREATE DATABASE IF NOT EXISTS " + dbName); - - // Verify database was created - String[] databases = tableEnv.listDatabases(); - assertThat(databases).contains(dbName); - } + @Test + @DisplayName("Test creating database in S3 Catalog via SQL DDL") + void testCreateDatabaseViaSql() throws Exception { + EnvironmentSettings settings = EnvironmentSettings.newInstance().inBatchMode().build(); + TableEnvironment tableEnv = TableEnvironment.create(settings); + + String catalogName = "lance_s3_db_sql_" + testId; + + // Create S3 Catalog + String createCatalogSql = + String.format( + "CREATE CATALOG %s WITH (" + + "'type' = 'lance', " + + "'warehouse' = '%s', " + + "'s3-access-key' = '%s', " + + "'s3-secret-key' = '%s', " + + "'s3-region' = 'us-east-1', " + + "'s3-endpoint' = '%s', " + + "'s3-allow-http' = 'true', " + + "'s3-virtual-hosted-style' = 'false'" + + ")", + catalogName, warehousePath, minioAccessKey, minioSecretKey, minioEndpoint); + + tableEnv.executeSql(createCatalogSql); + tableEnv.useCatalog(catalogName); + + // Create database + String dbName = "test_database_" + testId; + tableEnv.executeSql("CREATE DATABASE IF NOT EXISTS " + dbName); + + // Verify database was created + String[] databases = tableEnv.listDatabases(); + assertThat(databases).contains(dbName); + } - // ==================== Multiple Catalog Tests ==================== - - @Test - @DisplayName("Test multiple S3 Catalog instances") - void testMultipleS3Catalogs() throws Exception { - Map storageOptions = new HashMap<>(); - storageOptions.put("aws_access_key_id", minioAccessKey); - storageOptions.put("aws_secret_access_key", minioSecretKey); - storageOptions.put("aws_region", "us-east-1"); - storageOptions.put("aws_endpoint", minioEndpoint); - storageOptions.put("allow_http", "true"); - - // Create first Catalog - LanceCatalog catalog1 = new LanceCatalog( - "catalog1", "default", - "s3://" + testBucket + "/warehouse1_" + testId, - storageOptions); - catalog1.open(); - - // Create second Catalog - LanceCatalog catalog2 = new LanceCatalog( - "catalog2", "default", - "s3://" + testBucket + "/warehouse2_" + testId, - storageOptions); - catalog2.open(); - - try { - // Verify two Catalogs work independently - assertThat(catalog1.getWarehouse()).isNotEqualTo(catalog2.getWarehouse()); - - String db1 = "db1_" + testId; - String db2 = "db2_" + testId; - - catalog1.createDatabase(db1, null, false); - catalog2.createDatabase(db2, null, false); - - assertThat(catalog1.listDatabases()).contains(db1); - assertThat(catalog2.listDatabases()).contains(db2); - assertThat(catalog1.listDatabases()).doesNotContain(db2); - assertThat(catalog2.listDatabases()).doesNotContain(db1); - } finally { - catalog1.close(); - catalog2.close(); - } - } + // ==================== Multiple Catalog Tests ==================== + + @Test + @DisplayName("Test multiple S3 Catalog instances") + void testMultipleS3Catalogs() throws Exception { + Map storageOptions = new HashMap<>(); + 
storageOptions.put("aws_access_key_id", minioAccessKey); + storageOptions.put("aws_secret_access_key", minioSecretKey); + storageOptions.put("aws_region", "us-east-1"); + storageOptions.put("aws_endpoint", minioEndpoint); + storageOptions.put("allow_http", "true"); + + // Create first Catalog + LanceCatalog catalog1 = + new LanceCatalog( + "catalog1", + "default", + "s3://" + testBucket + "/warehouse1_" + testId, + storageOptions); + catalog1.open(); + + // Create second Catalog + LanceCatalog catalog2 = + new LanceCatalog( + "catalog2", + "default", + "s3://" + testBucket + "/warehouse2_" + testId, + storageOptions); + catalog2.open(); + + try { + // Verify two Catalogs work independently + assertThat(catalog1.getWarehouse()).isNotEqualTo(catalog2.getWarehouse()); + + String db1 = "db1_" + testId; + String db2 = "db2_" + testId; + + catalog1.createDatabase(db1, null, false); + catalog2.createDatabase(db2, null, false); + + assertThat(catalog1.listDatabases()).contains(db1); + assertThat(catalog2.listDatabases()).contains(db2); + assertThat(catalog1.listDatabases()).doesNotContain(db2); + assertThat(catalog2.listDatabases()).doesNotContain(db1); + } finally { + catalog1.close(); + catalog2.close(); + } } + } } diff --git a/src/test/java/org/apache/flink/connector/lance/table/LanceReadOptimizationsTest.java b/src/test/java/org/apache/flink/connector/lance/table/LanceReadOptimizationsTest.java index 9b3a5f0..2a6c7e2 100644 --- a/src/test/java/org/apache/flink/connector/lance/table/LanceReadOptimizationsTest.java +++ b/src/test/java/org/apache/flink/connector/lance/table/LanceReadOptimizationsTest.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,18 +11,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.flink.connector.lance.table; import org.apache.flink.connector.lance.config.LanceOptions; import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.connector.source.abilities.SupportsFilterPushDown; import org.apache.flink.table.expressions.CallExpression; import org.apache.flink.table.expressions.FieldReferenceExpression; import org.apache.flink.table.expressions.ResolvedExpression; import org.apache.flink.table.expressions.ValueLiteralExpression; -import org.apache.flink.table.functions.BuiltInFunctionDefinitions; -import org.apache.flink.table.connector.source.abilities.SupportsFilterPushDown; import org.apache.flink.table.functions.BuiltInFunctionDefinition; +import org.apache.flink.table.functions.BuiltInFunctionDefinitions; import org.apache.flink.table.types.DataType; import org.junit.jupiter.api.BeforeEach; @@ -45,484 +40,477 @@ /** * Read optimization tests - * + * *
  * <p>Test contents:
+ *
  * <ul>
- *   <li>Limit push-down</li>
- *   <li>Predicate push-down (basic comparison, IN, BETWEEN)</li>
- *   <li>Column pruning</li>
+ *   <li>Limit push-down
+ *   <li>Predicate push-down (basic comparison, IN, BETWEEN)
+ *   <li>Column pruning
  * </ul>
*/ @DisplayName("Read Optimization Tests") public class LanceReadOptimizationsTest { - @TempDir - File tempDir; - - private LanceOptions baseOptions; - private DataType physicalDataType; - - @BeforeEach - void setUp() { - baseOptions = LanceOptions.builder() - .path(tempDir.getAbsolutePath() + "/test_dataset") - .readBatchSize(100) - .build(); - - // Define test table schema - physicalDataType = DataTypes.ROW( - DataTypes.FIELD("id", DataTypes.BIGINT()), - DataTypes.FIELD("name", DataTypes.STRING()), - DataTypes.FIELD("status", DataTypes.STRING()), - DataTypes.FIELD("score", DataTypes.DOUBLE()), - DataTypes.FIELD("created_time", DataTypes.STRING()) - ); + @TempDir File tempDir; + + private LanceOptions baseOptions; + private DataType physicalDataType; + + @BeforeEach + void setUp() { + baseOptions = + LanceOptions.builder() + .path(tempDir.getAbsolutePath() + "/test_dataset") + .readBatchSize(100) + .build(); + + // Define test table schema + physicalDataType = + DataTypes.ROW( + DataTypes.FIELD("id", DataTypes.BIGINT()), + DataTypes.FIELD("name", DataTypes.STRING()), + DataTypes.FIELD("status", DataTypes.STRING()), + DataTypes.FIELD("score", DataTypes.DOUBLE()), + DataTypes.FIELD("created_time", DataTypes.STRING())); + } + + // ==================== Limit Push-Down Tests ==================== + + @Nested + @DisplayName("Limit Push-Down Tests") + class LimitPushDownTests { + + @Test + @DisplayName("Test applyLimit method") + void testApplyLimit() { + LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); + + // Initial state should have no limit + assertNull(source.getLimit(), "Initial limit should be null"); + + // Apply limit + source.applyLimit(100); + + // Verify limit is set + assertEquals(100L, source.getLimit(), "Limit should be correctly set to 100"); + } + + @Test + @DisplayName("Test Limit of 0") + void testZeroLimit() { + LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); + source.applyLimit(0); + assertEquals(0L, source.getLimit(), "Limit should be settable to 0"); } - // ==================== Limit Push-Down Tests ==================== - - @Nested - @DisplayName("Limit Push-Down Tests") - class LimitPushDownTests { - - @Test - @DisplayName("Test applyLimit method") - void testApplyLimit() { - LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); - - // Initial state should have no limit - assertNull(source.getLimit(), "Initial limit should be null"); - - // Apply limit - source.applyLimit(100); - - // Verify limit is set - assertEquals(100L, source.getLimit(), "Limit should be correctly set to 100"); - } - - @Test - @DisplayName("Test Limit of 0") - void testZeroLimit() { - LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); - source.applyLimit(0); - assertEquals(0L, source.getLimit(), "Limit should be settable to 0"); - } - - @Test - @DisplayName("Test large Limit value") - void testLargeLimit() { - LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); - long largeLimit = Long.MAX_VALUE; - source.applyLimit(largeLimit); - assertEquals(largeLimit, source.getLimit(), "Should support large Limit values"); - } - - @Test - @DisplayName("Test copy preserves Limit") - void testCopyPreservesLimit() { - LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); - source.applyLimit(50); - - LanceDynamicTableSource copied = (LanceDynamicTableSource) 
source.copy(); - - assertEquals(50L, copied.getLimit(), "copy() should preserve limit value"); - } + @Test + @DisplayName("Test large Limit value") + void testLargeLimit() { + LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); + long largeLimit = Long.MAX_VALUE; + source.applyLimit(largeLimit); + assertEquals(largeLimit, source.getLimit(), "Should support large Limit values"); } - // ==================== Predicate Push-Down Tests ==================== - - @Nested - @DisplayName("Predicate Push-Down Tests") - class FilterPushDownTests { - - @Test - @DisplayName("Test equals comparison push-down") - void testEqualsFilterPushDown() { - LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); - - // Create status = 'active' expression - List filters = createEqualsFilter("status", "active"); - - SupportsFilterPushDown.Result result = source.applyFilters(filters); - - // Verify filter is accepted - assertEquals(1, result.getAcceptedFilters().size(), "Equals comparison should be accepted"); - assertEquals(0, result.getRemainingFilters().size(), "Should not have remaining filters"); - } - - @Test - @DisplayName("Test numeric comparison push-down") - void testNumericComparisonPushDown() { - LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); - - // Create score > 80 expression - List filters = createComparisonFilter("score", 80.0, BuiltInFunctionDefinitions.GREATER_THAN); - - SupportsFilterPushDown.Result result = source.applyFilters(filters); - - assertEquals(1, result.getAcceptedFilters().size(), "Numeric comparison should be accepted"); - } - - @Test - @DisplayName("Test AND logic push-down") - void testAndLogicPushDown() { - LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); - - // Create status = 'active' AND score > 60 expression - ResolvedExpression statusFilter = createEqualsExpression("status", "active"); - ResolvedExpression scoreFilter = createComparisonExpression("score", 60.0, BuiltInFunctionDefinitions.GREATER_THAN); - - CallExpression andExpr = CallExpression.permanent( - BuiltInFunctionDefinitions.AND, - Arrays.asList(statusFilter, scoreFilter), - DataTypes.BOOLEAN() - ); - - SupportsFilterPushDown.Result result = source.applyFilters(Collections.singletonList(andExpr)); - - assertEquals(1, result.getAcceptedFilters().size(), "AND logic should be accepted"); - } - - @Test - @DisplayName("Test IS NULL push-down") - void testIsNullPushDown() { - LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); - - // Create name IS NULL expression - FieldReferenceExpression fieldRef = new FieldReferenceExpression( - "name", DataTypes.STRING(), 0, 1); - - CallExpression isNullExpr = CallExpression.permanent( - BuiltInFunctionDefinitions.IS_NULL, - Collections.singletonList(fieldRef), - DataTypes.BOOLEAN() - ); - - SupportsFilterPushDown.Result result = source.applyFilters(Collections.singletonList(isNullExpr)); - - assertEquals(1, result.getAcceptedFilters().size(), "IS NULL should be accepted"); - } - - @Test - @DisplayName("Test IS NOT NULL push-down") - void testIsNotNullPushDown() { - LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); - - // Create name IS NOT NULL expression - FieldReferenceExpression fieldRef = new FieldReferenceExpression( - "name", DataTypes.STRING(), 0, 1); - - CallExpression isNotNullExpr = CallExpression.permanent( - 
BuiltInFunctionDefinitions.IS_NOT_NULL, - Collections.singletonList(fieldRef), - DataTypes.BOOLEAN() - ); - - SupportsFilterPushDown.Result result = source.applyFilters(Collections.singletonList(isNotNullExpr)); - - assertEquals(1, result.getAcceptedFilters().size(), "IS NOT NULL should be accepted"); - } - - @Test - @DisplayName("Test LIKE push-down") - void testLikePushDown() { - LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); - - // Create name LIKE 'test%' expression - FieldReferenceExpression fieldRef = new FieldReferenceExpression( - "name", DataTypes.STRING(), 0, 1); - ValueLiteralExpression pattern = new ValueLiteralExpression("test%"); - - CallExpression likeExpr = CallExpression.permanent( - BuiltInFunctionDefinitions.LIKE, - Arrays.asList(fieldRef, pattern), - DataTypes.BOOLEAN() - ); - - SupportsFilterPushDown.Result result = source.applyFilters(Collections.singletonList(likeExpr)); - - assertEquals(1, result.getAcceptedFilters().size(), "LIKE should be accepted"); - } - - @Test - @DisplayName("Test IN predicate push-down") - void testInPredicatePushDown() { - LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); - - // Create status IN ('active', 'pending', 'completed') expression - FieldReferenceExpression fieldRef = new FieldReferenceExpression( - "status", DataTypes.STRING(), 0, 2); - ValueLiteralExpression value1 = new ValueLiteralExpression("active"); - ValueLiteralExpression value2 = new ValueLiteralExpression("pending"); - ValueLiteralExpression value3 = new ValueLiteralExpression("completed"); - - CallExpression inExpr = CallExpression.permanent( - BuiltInFunctionDefinitions.IN, - Arrays.asList(fieldRef, value1, value2, value3), - DataTypes.BOOLEAN() - ); - - SupportsFilterPushDown.Result result = source.applyFilters(Collections.singletonList(inExpr)); - - assertEquals(1, result.getAcceptedFilters().size(), "IN predicate should be accepted"); - } - - @Test - @DisplayName("Test multiple independent filter conditions") - void testMultipleFilters() { - LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); - - // Create multiple independent filter conditions - List<ResolvedExpression> filter1 = createEqualsFilter("status", "active"); - List<ResolvedExpression> filter2 = createComparisonFilter("score", 60.0, BuiltInFunctionDefinitions.GREATER_THAN_OR_EQUAL); - - List<ResolvedExpression> allFilters = new ArrayList<>(); - allFilters.addAll(filter1); - allFilters.addAll(filter2); - - SupportsFilterPushDown.Result result = source.applyFilters(allFilters); - - assertEquals(2, result.getAcceptedFilters().size(), "Two filter conditions should be accepted"); - assertEquals(0, result.getRemainingFilters().size(), "Should not have remaining filters"); - } - - @Test - @DisplayName("Test copy preserves filter conditions") - void testCopyPreservesFilters() { - LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); - - // Apply filter conditions - List<ResolvedExpression> filters = createEqualsFilter("status", "active"); - source.applyFilters(filters); - - LanceDynamicTableSource copied = (LanceDynamicTableSource) source.copy(); - - // Verify copied source preserves filter conditions - assertNotNull(copied, "copy() should succeed"); - } + @Test + @DisplayName("Test copy preserves Limit") + void testCopyPreservesLimit() { + LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); + source.applyLimit(50); + + LanceDynamicTableSource copied = (LanceDynamicTableSource) 
source.copy(); + + assertEquals(50L, copied.getLimit(), "copy() should preserve limit value"); } + } - // ==================== Column Pruning Tests ==================== + // ==================== Predicate Push-Down Tests ==================== - @Nested - @DisplayName("Column Pruning Tests") - class ProjectionPushDownTests { + @Nested + @DisplayName("Predicate Push-Down Tests") + class FilterPushDownTests { - @Test - @DisplayName("Test single column projection") - void testSingleColumnProjection() { - LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); + @Test + @DisplayName("Test equals comparison push-down") + void testEqualsFilterPushDown() { + LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); - // Select only id column - int[][] projection = {{0}}; // First column - source.applyProjection(projection); + // Create status = 'active' expression + List<ResolvedExpression> filters = createEqualsFilter("status", "active"); - // Verify projection is applied - assertNotNull(source, "Projection should be successfully applied"); - } + SupportsFilterPushDown.Result result = source.applyFilters(filters); - @Test - @DisplayName("Test multiple column projection") - void testMultipleColumnProjection() { - LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); + // Verify filter is accepted + assertEquals(1, result.getAcceptedFilters().size(), "Equals comparison should be accepted"); + assertEquals(0, result.getRemainingFilters().size(), "Should not have remaining filters"); + } + + @Test + @DisplayName("Test numeric comparison push-down") + void testNumericComparisonPushDown() { + LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); - // Select id, name, score columns - int[][] projection = {{0}, {1}, {3}}; - source.applyProjection(projection); + // Create score > 80 expression + List<ResolvedExpression> filters = + createComparisonFilter("score", 80.0, BuiltInFunctionDefinitions.GREATER_THAN); - assertNotNull(source, "Multiple column projection should be successfully applied"); - } + SupportsFilterPushDown.Result result = source.applyFilters(filters); - @Test - @DisplayName("Test nested projection not supported") - void testNestedProjectionNotSupported() { - LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); + assertEquals(1, result.getAcceptedFilters().size(), "Numeric comparison should be accepted"); + } - assertFalse(source.supportsNestedProjection(), "Should not support nested projection"); - } + @Test + @DisplayName("Test AND logic push-down") + void testAndLogicPushDown() { + LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); - @Test - @DisplayName("Test copy preserves projection") - void testCopyPreservesProjection() { - LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); + // Create status = 'active' AND score > 60 expression + ResolvedExpression statusFilter = createEqualsExpression("status", "active"); + ResolvedExpression scoreFilter = + createComparisonExpression("score", 60.0, BuiltInFunctionDefinitions.GREATER_THAN); - int[][] projection = {{0}, {2}}; - source.applyProjection(projection); + CallExpression andExpr = + CallExpression.permanent( + BuiltInFunctionDefinitions.AND, + Arrays.asList(statusFilter, scoreFilter), + DataTypes.BOOLEAN()); - LanceDynamicTableSource copied = (LanceDynamicTableSource) source.copy(); + 
SupportsFilterPushDown.Result result = + source.applyFilters(Collections.singletonList(andExpr)); - assertNotNull(copied, "copy() should preserve projection information"); - } + assertEquals(1, result.getAcceptedFilters().size(), "AND logic should be accepted"); } - // ==================== Combined Tests ==================== + @Test + @DisplayName("Test IS NULL push-down") + void testIsNullPushDown() { + LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); - @Nested - @DisplayName("Combined Optimization Tests") - class CombinedOptimizationsTests { + // Create name IS NULL expression + FieldReferenceExpression fieldRef = + new FieldReferenceExpression("name", DataTypes.STRING(), 0, 1); - @Test - @DisplayName("Test Limit + filter condition combination") - void testLimitWithFilter() { - LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); + CallExpression isNullExpr = + CallExpression.permanent( + BuiltInFunctionDefinitions.IS_NULL, + Collections.singletonList(fieldRef), + DataTypes.BOOLEAN()); - // Apply filter condition - List<ResolvedExpression> filters = createEqualsFilter("status", "active"); - source.applyFilters(filters); + SupportsFilterPushDown.Result result = + source.applyFilters(Collections.singletonList(isNullExpr)); - // Apply limit - source.applyLimit(100L); + assertEquals(1, result.getAcceptedFilters().size(), "IS NULL should be accepted"); + } - assertEquals(Long.valueOf(100L), source.getLimit(), "Limit should be correctly set"); - } + @Test + @DisplayName("Test IS NOT NULL push-down") + void testIsNotNullPushDown() { + LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); - @Test - @DisplayName("Test Limit + projection combination") - void testLimitWithProjection() { - LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); + // Create name IS NOT NULL expression + FieldReferenceExpression fieldRef = + new FieldReferenceExpression("name", DataTypes.STRING(), 0, 1); - // Apply projection - int[][] projection = {{0}, {1}}; - source.applyProjection(projection); + CallExpression isNotNullExpr = + CallExpression.permanent( + BuiltInFunctionDefinitions.IS_NOT_NULL, + Collections.singletonList(fieldRef), + DataTypes.BOOLEAN()); - // Apply limit - source.applyLimit(50L); + SupportsFilterPushDown.Result result = + source.applyFilters(Collections.singletonList(isNotNullExpr)); - assertEquals(Long.valueOf(50L), source.getLimit(), "Limit should be correctly set"); - } + assertEquals(1, result.getAcceptedFilters().size(), "IS NOT NULL should be accepted"); + } - @Test - @DisplayName("Test all optimizations combined") - void testAllOptimizations() { - LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); + @Test + @DisplayName("Test LIKE push-down") + void testLikePushDown() { + LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); - // 1. Apply projection - int[][] projection = {{0}, {1}, {3}}; // id, name, score - source.applyProjection(projection); + // Create name LIKE 'test%' expression + FieldReferenceExpression fieldRef = + new FieldReferenceExpression("name", DataTypes.STRING(), 0, 1); + ValueLiteralExpression pattern = new ValueLiteralExpression("test%"); - // 2. 
Apply filter condition - List<ResolvedExpression> filters = createComparisonFilter("score", 60.0, BuiltInFunctionDefinitions.GREATER_THAN_OR_EQUAL); - SupportsFilterPushDown.Result result = source.applyFilters(filters); + CallExpression likeExpr = + CallExpression.permanent( + BuiltInFunctionDefinitions.LIKE, + Arrays.asList(fieldRef, pattern), + DataTypes.BOOLEAN()); - // 3. Apply limit - source.applyLimit(100L); + SupportsFilterPushDown.Result result = + source.applyFilters(Collections.singletonList(likeExpr)); - // Verify all optimizations are correctly applied - assertEquals(1, result.getAcceptedFilters().size(), "Filter condition should be accepted"); - assertEquals(Long.valueOf(100L), source.getLimit(), "Limit should be correctly set"); - } + assertEquals(1, result.getAcceptedFilters().size(), "LIKE should be accepted"); } - // ==================== LanceOptions Tests ==================== - - @Nested - @DisplayName("LanceOptions Limit Configuration Tests") - class LanceOptionsLimitTests { - - @Test - @DisplayName("Test readLimit configuration") - void testReadLimitConfig() { - LanceOptions options = LanceOptions.builder() - .path("/test/path") - .readLimit(500L) - .build(); - - assertEquals(500L, options.getReadLimit(), "readLimit should be correctly configured"); - } - - @Test - @DisplayName("Test readLimit default value") - void testReadLimitDefault() { - LanceOptions options = LanceOptions.builder() - .path("/test/path") - .build(); - - assertNull(options.getReadLimit(), "readLimit default should be null"); - } - - @Test - @DisplayName("Test readLimit of 0") - void testReadLimitZero() { - // 0 should be allowed (means don't read any data) - LanceOptions options = LanceOptions.builder() - .path("/test/path") - .readLimit(0L) - .build(); - - assertEquals(0L, options.getReadLimit()); - } - - @Test - @DisplayName("Test negative readLimit should fail") - void testNegativeReadLimit() { - assertThrows(IllegalArgumentException.class, () -> { - LanceOptions.builder() - .path("/test/path") - .readLimit(-1L) - .build(); - }, "Negative readLimit should throw exception"); - } + @Test + @DisplayName("Test IN predicate push-down") + void testInPredicatePushDown() { + LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); + + // Create status IN ('active', 'pending', 'completed') expression + FieldReferenceExpression fieldRef = + new FieldReferenceExpression("status", DataTypes.STRING(), 0, 2); + ValueLiteralExpression value1 = new ValueLiteralExpression("active"); + ValueLiteralExpression value2 = new ValueLiteralExpression("pending"); + ValueLiteralExpression value3 = new ValueLiteralExpression("completed"); + + CallExpression inExpr = + CallExpression.permanent( + BuiltInFunctionDefinitions.IN, + Arrays.asList(fieldRef, value1, value2, value3), + DataTypes.BOOLEAN()); + + SupportsFilterPushDown.Result result = source.applyFilters(Collections.singletonList(inExpr)); + + assertEquals(1, result.getAcceptedFilters().size(), "IN predicate should be accepted"); } - // ==================== Helper Methods ==================== + @Test + @DisplayName("Test multiple independent filter conditions") + void testMultipleFilters() { + LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); + + // Create multiple independent filter conditions + List<ResolvedExpression> filter1 = createEqualsFilter("status", "active"); + List<ResolvedExpression> filter2 = + createComparisonFilter("score", 60.0, BuiltInFunctionDefinitions.GREATER_THAN_OR_EQUAL); - /** - * Create equals comparison filter expression 
- */ - private List<ResolvedExpression> createEqualsFilter(String fieldName, String value) { - ResolvedExpression expr = createEqualsExpression(fieldName, value); - return Collections.singletonList(expr); + List<ResolvedExpression> allFilters = new ArrayList<>(); + allFilters.addAll(filter1); + allFilters.addAll(filter2); + + SupportsFilterPushDown.Result result = source.applyFilters(allFilters); + + assertEquals( + 2, result.getAcceptedFilters().size(), "Two filter conditions should be accepted"); + assertEquals(0, result.getRemainingFilters().size(), "Should not have remaining filters"); } - /** - * Create equals comparison expression - */ - private ResolvedExpression createEqualsExpression(String fieldName, String value) { - FieldReferenceExpression fieldRef = new FieldReferenceExpression( - fieldName, DataTypes.STRING(), 0, getFieldIndex(fieldName)); - ValueLiteralExpression literal = new ValueLiteralExpression(value); - - return CallExpression.permanent( - BuiltInFunctionDefinitions.EQUALS, - Arrays.asList(fieldRef, literal), - DataTypes.BOOLEAN() - ); + @Test + @DisplayName("Test copy preserves filter conditions") + void testCopyPreservesFilters() { + LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); + + // Apply filter conditions + List<ResolvedExpression> filters = createEqualsFilter("status", "active"); + source.applyFilters(filters); + + LanceDynamicTableSource copied = (LanceDynamicTableSource) source.copy(); + + // Verify copied source preserves filter conditions + assertNotNull(copied, "copy() should succeed"); } + } - /** - * Create comparison filter expression - */ - private List<ResolvedExpression> createComparisonFilter(String fieldName, Double value, BuiltInFunctionDefinition funcDef) { - ResolvedExpression expr = createComparisonExpression(fieldName, value, funcDef); - return Collections.singletonList(expr); + // ==================== Column Pruning Tests ==================== + + @Nested + @DisplayName("Column Pruning Tests") + class ProjectionPushDownTests { + + @Test + @DisplayName("Test single column projection") + void testSingleColumnProjection() { + LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); + + // Select only id column + int[][] projection = {{0}}; // First column + source.applyProjection(projection); + + // Verify projection is applied + assertNotNull(source, "Projection should be successfully applied"); } - /** - * Create comparison expression - */ - private ResolvedExpression createComparisonExpression(String fieldName, Double value, BuiltInFunctionDefinition funcDef) { - FieldReferenceExpression fieldRef = new FieldReferenceExpression( - fieldName, DataTypes.DOUBLE(), 0, getFieldIndex(fieldName)); - ValueLiteralExpression literal = new ValueLiteralExpression(value); - - return CallExpression.permanent( - funcDef, - Arrays.asList(fieldRef, literal), - DataTypes.BOOLEAN() - ); + @Test + @DisplayName("Test multiple column projection") + void testMultipleColumnProjection() { + LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); + + // Select id, name, score columns + int[][] projection = {{0}, {1}, {3}}; + source.applyProjection(projection); + + assertNotNull(source, "Multiple column projection should be successfully applied"); + } + + @Test + @DisplayName("Test nested projection not supported") + void testNestedProjectionNotSupported() { + LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); + + assertFalse(source.supportsNestedProjection(), "Should not support nested 
projection"); + } + + @Test + @DisplayName("Test copy preserves projection") + void testCopyPreservesProjection() { + LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); + + int[][] projection = {{0}, {2}}; + source.applyProjection(projection); + + LanceDynamicTableSource copied = (LanceDynamicTableSource) source.copy(); + + assertNotNull(copied, "copy() should preserve projection information"); + } + } + + // ==================== Combined Tests ==================== + + @Nested + @DisplayName("Combined Optimization Tests") + class CombinedOptimizationsTests { + + @Test + @DisplayName("Test Limit + filter condition combination") + void testLimitWithFilter() { + LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); + + // Apply filter condition + List<ResolvedExpression> filters = createEqualsFilter("status", "active"); + source.applyFilters(filters); + + // Apply limit + source.applyLimit(100L); + + assertEquals(Long.valueOf(100L), source.getLimit(), "Limit should be correctly set"); } - /** - * Get field index - */ - private int getFieldIndex(String fieldName) { - switch (fieldName) { - case "id": return 0; - case "name": return 1; - case "status": return 2; - case "score": return 3; - case "created_time": return 4; - default: return 0; - } + @Test + @DisplayName("Test Limit + projection combination") + void testLimitWithProjection() { + LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); + + // Apply projection + int[][] projection = {{0}, {1}}; + source.applyProjection(projection); + + // Apply limit + source.applyLimit(50L); + + assertEquals(Long.valueOf(50L), source.getLimit(), "Limit should be correctly set"); + } + + @Test + @DisplayName("Test all optimizations combined") + void testAllOptimizations() { + LanceDynamicTableSource source = new LanceDynamicTableSource(baseOptions, physicalDataType); + + // 1. Apply projection + int[][] projection = {{0}, {1}, {3}}; // id, name, score + source.applyProjection(projection); + + // 2. Apply filter condition + List<ResolvedExpression> filters = + createComparisonFilter("score", 60.0, BuiltInFunctionDefinitions.GREATER_THAN_OR_EQUAL); + SupportsFilterPushDown.Result result = source.applyFilters(filters); + + // 3. 
Apply limit + source.applyLimit(100L); + + // Verify all optimizations are correctly applied + assertEquals(1, result.getAcceptedFilters().size(), "Filter condition should be accepted"); + assertEquals(Long.valueOf(100L), source.getLimit(), "Limit should be correctly set"); + } + } + + // ==================== LanceOptions Tests ==================== + + @Nested + @DisplayName("LanceOptions Limit Configuration Tests") + class LanceOptionsLimitTests { + + @Test + @DisplayName("Test readLimit configuration") + void testReadLimitConfig() { + LanceOptions options = LanceOptions.builder().path("/test/path").readLimit(500L).build(); + + assertEquals(500L, options.getReadLimit(), "readLimit should be correctly configured"); + } + + @Test + @DisplayName("Test readLimit default value") + void testReadLimitDefault() { + LanceOptions options = LanceOptions.builder().path("/test/path").build(); + + assertNull(options.getReadLimit(), "readLimit default should be null"); + } + + @Test + @DisplayName("Test readLimit of 0") + void testReadLimitZero() { + // 0 should be allowed (means don't read any data) + LanceOptions options = LanceOptions.builder().path("/test/path").readLimit(0L).build(); + + assertEquals(0L, options.getReadLimit()); + } + + @Test + @DisplayName("Test negative readLimit should fail") + void testNegativeReadLimit() { + assertThrows( + IllegalArgumentException.class, + () -> { + LanceOptions.builder().path("/test/path").readLimit(-1L).build(); + }, + "Negative readLimit should throw exception"); + } + } + + // ==================== Helper Methods ==================== + + /** Create equals comparison filter expression */ + private List<ResolvedExpression> createEqualsFilter(String fieldName, String value) { + ResolvedExpression expr = createEqualsExpression(fieldName, value); + return Collections.singletonList(expr); + } + + /** Create equals comparison expression */ + private ResolvedExpression createEqualsExpression(String fieldName, String value) { + FieldReferenceExpression fieldRef = + new FieldReferenceExpression(fieldName, DataTypes.STRING(), 0, getFieldIndex(fieldName)); + ValueLiteralExpression literal = new ValueLiteralExpression(value); + + return CallExpression.permanent( + BuiltInFunctionDefinitions.EQUALS, Arrays.asList(fieldRef, literal), DataTypes.BOOLEAN()); + } + + /** Create comparison filter expression */ + private List<ResolvedExpression> createComparisonFilter( + String fieldName, Double value, BuiltInFunctionDefinition funcDef) { + ResolvedExpression expr = createComparisonExpression(fieldName, value, funcDef); + return Collections.singletonList(expr); + } + + /** Create comparison expression */ + private ResolvedExpression createComparisonExpression( + String fieldName, Double value, BuiltInFunctionDefinition funcDef) { + FieldReferenceExpression fieldRef = + new FieldReferenceExpression(fieldName, DataTypes.DOUBLE(), 0, getFieldIndex(fieldName)); + ValueLiteralExpression literal = new ValueLiteralExpression(value); + + return CallExpression.permanent(funcDef, Arrays.asList(fieldRef, literal), DataTypes.BOOLEAN()); + } + + /** Get field index */ + private int getFieldIndex(String fieldName) { + switch (fieldName) { + case "id": + return 0; + case "name": + return 1; + case "status": + return 2; + case "score": + return 3; + case "created_time": + return 4; + default: + return 0; } + } } diff --git a/src/test/java/org/apache/flink/connector/lance/table/LanceSqlITCase.java b/src/test/java/org/apache/flink/connector/lance/table/LanceSqlITCase.java index f121e43..c5c75db 100644 --- 
a/src/test/java/org/apache/flink/connector/lance/table/LanceSqlITCase.java +++ b/src/test/java/org/apache/flink/connector/lance/table/LanceSqlITCase.java @@ -1,11 +1,7 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -15,12 +11,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.flink.connector.lance.table; import org.apache.flink.connector.lance.config.LanceOptions; import org.apache.flink.table.api.DataTypes; -import org.apache.flink.table.api.Schema; import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.logical.ArrayType; import org.apache.flink.table.types.logical.BigIntType; @@ -43,305 +37,289 @@ import static org.assertj.core.api.Assertions.assertThat; -/** - * Lance SQL integration tests. - */ +/** Lance SQL integration tests. */ class LanceSqlITCase { - @TempDir - Path tempDir; - - private String datasetPath; - private String warehousePath; - - @BeforeEach - void setUp() { - datasetPath = tempDir.resolve("test_sql_dataset").toString(); - warehousePath = tempDir.resolve("test_warehouse").toString(); - } - - @Test - @DisplayName("Test LanceDynamicTableFactory identifier") - void testFactoryIdentifier() { - LanceDynamicTableFactory factory = new LanceDynamicTableFactory(); - assertThat(factory.factoryIdentifier()).isEqualTo("lance"); - } - - @Test - @DisplayName("Test LanceDynamicTableFactory required options") - void testRequiredOptions() { - LanceDynamicTableFactory factory = new LanceDynamicTableFactory(); - Set<String> requiredOptionKeys = new HashSet<>(); - factory.requiredOptions().forEach(opt -> requiredOptionKeys.add(opt.key())); - - assertThat(requiredOptionKeys).contains("path"); - } - - @Test - @DisplayName("Test LanceDynamicTableFactory optional options") - void testOptionalOptions() { - LanceDynamicTableFactory factory = new LanceDynamicTableFactory(); - Set<String> optionalOptionKeys = new HashSet<>(); - factory.optionalOptions().forEach(opt -> optionalOptionKeys.add(opt.key())); - - assertThat(optionalOptionKeys).contains( - "read.batch-size", - "read.columns", - "read.filter", - "write.batch-size", - "write.mode", - "write.max-rows-per-file", - "index.type", - "index.column", - "vector.column", - "vector.metric" - ); - } - - @Test - @DisplayName("Test LanceDynamicTableSource creation") - void testDynamicTableSourceCreation() { - LanceOptions options = LanceOptions.builder() - .path(datasetPath) - .readBatchSize(512) - .build(); - - List<RowType.RowField> fields = new ArrayList<>(); - fields.add(new RowType.RowField("id", new BigIntType())); - fields.add(new RowType.RowField("content", new VarCharType())); - fields.add(new RowType.RowField("embedding", new ArrayType(new FloatType()))); - RowType rowType = new RowType(fields); - - DataType dataType = DataTypes.ROW( - DataTypes.FIELD("id", DataTypes.BIGINT()), - DataTypes.FIELD("content", DataTypes.STRING()), - DataTypes.FIELD("embedding", 
DataTypes.ARRAY(DataTypes.FLOAT())) - ); - - LanceDynamicTableSource source = new LanceDynamicTableSource(options, dataType); - - assertThat(source.getOptions()).isEqualTo(options); - assertThat(source.getPhysicalDataType()).isEqualTo(dataType); - assertThat(source.asSummaryString()).isEqualTo("Lance Table Source"); - } - - @Test - @DisplayName("Test LanceDynamicTableSink creation") - void testDynamicTableSinkCreation() { - LanceOptions options = LanceOptions.builder() - .path(datasetPath) - .writeBatchSize(256) - .writeMode(LanceOptions.WriteMode.APPEND) - .build(); - - DataType dataType = DataTypes.ROW( - DataTypes.FIELD("id", DataTypes.BIGINT()), - DataTypes.FIELD("content", DataTypes.STRING()), - DataTypes.FIELD("embedding", DataTypes.ARRAY(DataTypes.FLOAT())) - ); - - LanceDynamicTableSink sink = new LanceDynamicTableSink(options, dataType); - - assertThat(sink.getOptions()).isEqualTo(options); - assertThat(sink.getPhysicalDataType()).isEqualTo(dataType); - assertThat(sink.asSummaryString()).isEqualTo("Lance Table Sink"); - } - - @Test - @DisplayName("Test LanceDynamicTableSource copy") - void testDynamicTableSourceCopy() { - LanceOptions options = LanceOptions.builder() - .path(datasetPath) - .build(); - - DataType dataType = DataTypes.ROW( - DataTypes.FIELD("id", DataTypes.BIGINT()) - ); - - LanceDynamicTableSource source = new LanceDynamicTableSource(options, dataType); - LanceDynamicTableSource copiedSource = (LanceDynamicTableSource) source.copy(); - - assertThat(copiedSource).isNotSameAs(source); - assertThat(copiedSource.getOptions()).isEqualTo(source.getOptions()); - } - - @Test - @DisplayName("Test LanceDynamicTableSink copy") - void testDynamicTableSinkCopy() { - LanceOptions options = LanceOptions.builder() - .path(datasetPath) - .build(); - - DataType dataType = DataTypes.ROW( - DataTypes.FIELD("id", DataTypes.BIGINT()) - ); - - LanceDynamicTableSink sink = new LanceDynamicTableSink(options, dataType); - LanceDynamicTableSink copiedSink = (LanceDynamicTableSink) sink.copy(); - - assertThat(copiedSink).isNotSameAs(sink); - assertThat(copiedSink.getOptions()).isEqualTo(sink.getOptions()); - } - - @Test - @DisplayName("Test LanceCatalogFactory identifier") - void testCatalogFactoryIdentifier() { - LanceCatalogFactory factory = new LanceCatalogFactory(); - assertThat(factory.factoryIdentifier()).isEqualTo("lance"); - } - - @Test - @DisplayName("Test LanceCatalogFactory required options") - void testCatalogRequiredOptions() { - LanceCatalogFactory factory = new LanceCatalogFactory(); - Set<String> requiredOptionKeys = new HashSet<>(); - factory.requiredOptions().forEach(opt -> requiredOptionKeys.add(opt.key())); - - assertThat(requiredOptionKeys).contains("warehouse"); - } - - @Test - @DisplayName("Test LanceCatalogFactory optional options") - void testCatalogOptionalOptions() { - LanceCatalogFactory factory = new LanceCatalogFactory(); - Set<String> optionalOptionKeys = new HashSet<>(); - factory.optionalOptions().forEach(opt -> optionalOptionKeys.add(opt.key())); - - assertThat(optionalOptionKeys).contains("default-database"); - } - - @Test - @DisplayName("Test LanceCatalog creation and basic operations") - void testLanceCatalogBasicOperations() throws Exception { - LanceCatalog catalog = new LanceCatalog("test_catalog", "default", warehousePath); - - try { - catalog.open(); - - // Verify default database exists - assertThat(catalog.databaseExists("default")).isTrue(); - - // List databases - List<String> databases = catalog.listDatabases(); - assertThat(databases).contains("default"); - - // Create 
new database - catalog.createDatabase("test_db", null, false); - assertThat(catalog.databaseExists("test_db")).isTrue(); - - // List tables (empty) - List<String> tables = catalog.listTables("test_db"); - assertThat(tables).isEmpty(); - - // Drop database - catalog.dropDatabase("test_db", false, true); - assertThat(catalog.databaseExists("test_db")).isFalse(); - - } finally { - catalog.close(); - } - } - - @Test - @DisplayName("Test LanceCatalog warehouse path") - void testLanceCatalogWarehouse() throws Exception { - LanceCatalog catalog = new LanceCatalog("test", "default", warehousePath); - - try { - catalog.open(); - assertThat(catalog.getWarehouse()).isEqualTo(warehousePath); - } finally { - catalog.close(); - } - } - - @Test - @DisplayName("Test configuration options definition") - void testConfigOptions() { - assertThat(LanceDynamicTableFactory.PATH.key()).isEqualTo("path"); - assertThat(LanceDynamicTableFactory.READ_BATCH_SIZE.key()).isEqualTo("read.batch-size"); - assertThat(LanceDynamicTableFactory.READ_BATCH_SIZE.defaultValue()).isEqualTo(1024); - assertThat(LanceDynamicTableFactory.WRITE_BATCH_SIZE.key()).isEqualTo("write.batch-size"); - assertThat(LanceDynamicTableFactory.WRITE_MODE.key()).isEqualTo("write.mode"); - assertThat(LanceDynamicTableFactory.WRITE_MODE.defaultValue()).isEqualTo("append"); - assertThat(LanceDynamicTableFactory.INDEX_TYPE.key()).isEqualTo("index.type"); - assertThat(LanceDynamicTableFactory.INDEX_TYPE.defaultValue()).isEqualTo("IVF_PQ"); - assertThat(LanceDynamicTableFactory.VECTOR_METRIC.key()).isEqualTo("vector.metric"); - assertThat(LanceDynamicTableFactory.VECTOR_METRIC.defaultValue()).isEqualTo("L2"); - } - - @Test - @DisplayName("Test Catalog configuration options definition") - void testCatalogConfigOptions() { - assertThat(LanceCatalogFactory.WAREHOUSE.key()).isEqualTo("warehouse"); - assertThat(LanceCatalogFactory.DEFAULT_DATABASE.key()).isEqualTo("default-database"); - assertThat(LanceCatalogFactory.DEFAULT_DATABASE.defaultValue()).isEqualTo("default"); - } - - @Test - @DisplayName("Test S3 Catalog configuration options definition") - void testS3CatalogConfigOptions() { - // S3 configuration options - assertThat(LanceCatalogFactory.S3_ACCESS_KEY.key()).isEqualTo("s3-access-key"); - assertThat(LanceCatalogFactory.S3_SECRET_KEY.key()).isEqualTo("s3-secret-key"); - assertThat(LanceCatalogFactory.S3_REGION.key()).isEqualTo("s3-region"); - assertThat(LanceCatalogFactory.S3_ENDPOINT.key()).isEqualTo("s3-endpoint"); - assertThat(LanceCatalogFactory.S3_VIRTUAL_HOSTED_STYLE.key()).isEqualTo("s3-virtual-hosted-style"); - assertThat(LanceCatalogFactory.S3_ALLOW_HTTP.key()).isEqualTo("s3-allow-http"); - - // Default values - assertThat(LanceCatalogFactory.S3_VIRTUAL_HOSTED_STYLE.defaultValue()).isTrue(); - assertThat(LanceCatalogFactory.S3_ALLOW_HTTP.defaultValue()).isFalse(); - } - - @Test - @DisplayName("Test LanceCatalog S3 remote storage detection") - void testLanceCatalogRemoteStorageDetection() { - // S3 path should be identified as remote storage - LanceCatalog s3Catalog = new LanceCatalog("test", "default", "s3://bucket/path"); - assertThat(s3Catalog.isRemoteStorage()).isTrue(); - - // S3A path - LanceCatalog s3aCatalog = new LanceCatalog("test", "default", "s3a://bucket/path"); - assertThat(s3aCatalog.isRemoteStorage()).isTrue(); - - // GCS path - LanceCatalog gcsCatalog = new LanceCatalog("test", "default", "gs://bucket/path"); - assertThat(gcsCatalog.isRemoteStorage()).isTrue(); - - // Azure path - LanceCatalog azCatalog = new LanceCatalog("test", 
"default", "az://container/path"); - assertThat(azCatalog.isRemoteStorage()).isTrue(); - - // Local path should be identified as local storage - LanceCatalog localCatalog = new LanceCatalog("test", "default", warehousePath); - assertThat(localCatalog.isRemoteStorage()).isFalse(); + @TempDir Path tempDir; + + private String datasetPath; + private String warehousePath; + + @BeforeEach + void setUp() { + datasetPath = tempDir.resolve("test_sql_dataset").toString(); + warehousePath = tempDir.resolve("test_warehouse").toString(); + } + + @Test + @DisplayName("Test LanceDynamicTableFactory identifier") + void testFactoryIdentifier() { + LanceDynamicTableFactory factory = new LanceDynamicTableFactory(); + assertThat(factory.factoryIdentifier()).isEqualTo("lance"); + } + + @Test + @DisplayName("Test LanceDynamicTableFactory required options") + void testRequiredOptions() { + LanceDynamicTableFactory factory = new LanceDynamicTableFactory(); + Set<String> requiredOptionKeys = new HashSet<>(); + factory.requiredOptions().forEach(opt -> requiredOptionKeys.add(opt.key())); + + assertThat(requiredOptionKeys).contains("path"); + } + + @Test + @DisplayName("Test LanceDynamicTableFactory optional options") + void testOptionalOptions() { + LanceDynamicTableFactory factory = new LanceDynamicTableFactory(); + Set<String> optionalOptionKeys = new HashSet<>(); + factory.optionalOptions().forEach(opt -> optionalOptionKeys.add(opt.key())); + + assertThat(optionalOptionKeys) + .contains( + "read.batch-size", + "read.columns", + "read.filter", + "write.batch-size", + "write.mode", + "write.max-rows-per-file", + "index.type", + "index.column", + "vector.column", + "vector.metric"); + } + + @Test + @DisplayName("Test LanceDynamicTableSource creation") + void testDynamicTableSourceCreation() { + LanceOptions options = LanceOptions.builder().path(datasetPath).readBatchSize(512).build(); + + List<RowType.RowField> fields = new ArrayList<>(); + fields.add(new RowType.RowField("id", new BigIntType())); + fields.add(new RowType.RowField("content", new VarCharType())); + fields.add(new RowType.RowField("embedding", new ArrayType(new FloatType()))); + RowType rowType = new RowType(fields); + + DataType dataType = + DataTypes.ROW( + DataTypes.FIELD("id", DataTypes.BIGINT()), + DataTypes.FIELD("content", DataTypes.STRING()), + DataTypes.FIELD("embedding", DataTypes.ARRAY(DataTypes.FLOAT()))); + + LanceDynamicTableSource source = new LanceDynamicTableSource(options, dataType); + + assertThat(source.getOptions()).isEqualTo(options); + assertThat(source.getPhysicalDataType()).isEqualTo(dataType); + assertThat(source.asSummaryString()).isEqualTo("Lance Table Source"); + } + + @Test + @DisplayName("Test LanceDynamicTableSink creation") + void testDynamicTableSinkCreation() { + LanceOptions options = + LanceOptions.builder() + .path(datasetPath) + .writeBatchSize(256) + .writeMode(LanceOptions.WriteMode.APPEND) + .build(); + + DataType dataType = + DataTypes.ROW( + DataTypes.FIELD("id", DataTypes.BIGINT()), + DataTypes.FIELD("content", DataTypes.STRING()), + DataTypes.FIELD("embedding", DataTypes.ARRAY(DataTypes.FLOAT()))); + + LanceDynamicTableSink sink = new LanceDynamicTableSink(options, dataType); + + assertThat(sink.getOptions()).isEqualTo(options); + assertThat(sink.getPhysicalDataType()).isEqualTo(dataType); + assertThat(sink.asSummaryString()).isEqualTo("Lance Table Sink"); + } + + @Test + @DisplayName("Test LanceDynamicTableSource copy") + void testDynamicTableSourceCopy() { + LanceOptions options = LanceOptions.builder().path(datasetPath).build(); + + 
DataType dataType = DataTypes.ROW(DataTypes.FIELD("id", DataTypes.BIGINT())); + + LanceDynamicTableSource source = new LanceDynamicTableSource(options, dataType); + LanceDynamicTableSource copiedSource = (LanceDynamicTableSource) source.copy(); + + assertThat(copiedSource).isNotSameAs(source); + assertThat(copiedSource.getOptions()).isEqualTo(source.getOptions()); + } + + @Test + @DisplayName("Test LanceDynamicTableSink copy") + void testDynamicTableSinkCopy() { + LanceOptions options = LanceOptions.builder().path(datasetPath).build(); + + DataType dataType = DataTypes.ROW(DataTypes.FIELD("id", DataTypes.BIGINT())); + + LanceDynamicTableSink sink = new LanceDynamicTableSink(options, dataType); + LanceDynamicTableSink copiedSink = (LanceDynamicTableSink) sink.copy(); + + assertThat(copiedSink).isNotSameAs(sink); + assertThat(copiedSink.getOptions()).isEqualTo(sink.getOptions()); + } + + @Test + @DisplayName("Test LanceCatalogFactory identifier") + void testCatalogFactoryIdentifier() { + LanceCatalogFactory factory = new LanceCatalogFactory(); + assertThat(factory.factoryIdentifier()).isEqualTo("lance"); + } + + @Test + @DisplayName("Test LanceCatalogFactory required options") + void testCatalogRequiredOptions() { + LanceCatalogFactory factory = new LanceCatalogFactory(); + Set<String> requiredOptionKeys = new HashSet<>(); + factory.requiredOptions().forEach(opt -> requiredOptionKeys.add(opt.key())); + + assertThat(requiredOptionKeys).contains("warehouse"); + } + + @Test + @DisplayName("Test LanceCatalogFactory optional options") + void testCatalogOptionalOptions() { + LanceCatalogFactory factory = new LanceCatalogFactory(); + Set<String> optionalOptionKeys = new HashSet<>(); + factory.optionalOptions().forEach(opt -> optionalOptionKeys.add(opt.key())); + + assertThat(optionalOptionKeys).contains("default-database"); + } + + @Test + @DisplayName("Test LanceCatalog creation and basic operations") + void testLanceCatalogBasicOperations() throws Exception { + LanceCatalog catalog = new LanceCatalog("test_catalog", "default", warehousePath); + + try { + catalog.open(); + + // Verify default database exists + assertThat(catalog.databaseExists("default")).isTrue(); + + // List databases + List<String> databases = catalog.listDatabases(); + assertThat(databases).contains("default"); + + // Create new database + catalog.createDatabase("test_db", null, false); + assertThat(catalog.databaseExists("test_db")).isTrue(); + + // List tables (empty) + List<String> tables = catalog.listTables("test_db"); + assertThat(tables).isEmpty(); + + // Drop database + catalog.dropDatabase("test_db", false, true); + assertThat(catalog.databaseExists("test_db")).isFalse(); + + } finally { + catalog.close(); } - - @Test - @DisplayName("Test LanceCatalog construction with storage options") - void testLanceCatalogWithStorageOptions() { - Map<String, String> storageOptions = new HashMap<>(); - storageOptions.put("aws_access_key_id", "test-key"); - storageOptions.put("aws_secret_access_key", "test-secret"); - storageOptions.put("aws_region", "us-east-1"); - - LanceCatalog catalog = new LanceCatalog( - "test_catalog", - "default", - "s3://bucket/warehouse", - storageOptions - ); - - assertThat(catalog.getStorageOptions()).containsEntry("aws_access_key_id", "test-key"); - assertThat(catalog.getStorageOptions()).containsEntry("aws_secret_access_key", "test-secret"); - assertThat(catalog.getStorageOptions()).containsEntry("aws_region", "us-east-1"); - } - - @Test - @DisplayName("Test vector search UDF configuration") - void testVectorSearchFunctionConfiguration() { - 
LanceVectorSearchFunction function = new LanceVectorSearchFunction(); - assertThat(function).isNotNull(); + } + + @Test + @DisplayName("Test LanceCatalog warehouse path") + void testLanceCatalogWarehouse() throws Exception { + LanceCatalog catalog = new LanceCatalog("test", "default", warehousePath); + + try { + catalog.open(); + assertThat(catalog.getWarehouse()).isEqualTo(warehousePath); + } finally { + catalog.close(); } + } + + @Test + @DisplayName("Test configuration options definition") + void testConfigOptions() { + assertThat(LanceDynamicTableFactory.PATH.key()).isEqualTo("path"); + assertThat(LanceDynamicTableFactory.READ_BATCH_SIZE.key()).isEqualTo("read.batch-size"); + assertThat(LanceDynamicTableFactory.READ_BATCH_SIZE.defaultValue()).isEqualTo(1024); + assertThat(LanceDynamicTableFactory.WRITE_BATCH_SIZE.key()).isEqualTo("write.batch-size"); + assertThat(LanceDynamicTableFactory.WRITE_MODE.key()).isEqualTo("write.mode"); + assertThat(LanceDynamicTableFactory.WRITE_MODE.defaultValue()).isEqualTo("append"); + assertThat(LanceDynamicTableFactory.INDEX_TYPE.key()).isEqualTo("index.type"); + assertThat(LanceDynamicTableFactory.INDEX_TYPE.defaultValue()).isEqualTo("IVF_PQ"); + assertThat(LanceDynamicTableFactory.VECTOR_METRIC.key()).isEqualTo("vector.metric"); + assertThat(LanceDynamicTableFactory.VECTOR_METRIC.defaultValue()).isEqualTo("L2"); + } + + @Test + @DisplayName("Test Catalog configuration options definition") + void testCatalogConfigOptions() { + assertThat(LanceCatalogFactory.WAREHOUSE.key()).isEqualTo("warehouse"); + assertThat(LanceCatalogFactory.DEFAULT_DATABASE.key()).isEqualTo("default-database"); + assertThat(LanceCatalogFactory.DEFAULT_DATABASE.defaultValue()).isEqualTo("default"); + } + + @Test + @DisplayName("Test S3 Catalog configuration options definition") + void testS3CatalogConfigOptions() { + // S3 configuration options + assertThat(LanceCatalogFactory.S3_ACCESS_KEY.key()).isEqualTo("s3-access-key"); + assertThat(LanceCatalogFactory.S3_SECRET_KEY.key()).isEqualTo("s3-secret-key"); + assertThat(LanceCatalogFactory.S3_REGION.key()).isEqualTo("s3-region"); + assertThat(LanceCatalogFactory.S3_ENDPOINT.key()).isEqualTo("s3-endpoint"); + assertThat(LanceCatalogFactory.S3_VIRTUAL_HOSTED_STYLE.key()) + .isEqualTo("s3-virtual-hosted-style"); + assertThat(LanceCatalogFactory.S3_ALLOW_HTTP.key()).isEqualTo("s3-allow-http"); + + // Default values + assertThat(LanceCatalogFactory.S3_VIRTUAL_HOSTED_STYLE.defaultValue()).isTrue(); + assertThat(LanceCatalogFactory.S3_ALLOW_HTTP.defaultValue()).isFalse(); + } + + @Test + @DisplayName("Test LanceCatalog S3 remote storage detection") + void testLanceCatalogRemoteStorageDetection() { + // S3 path should be identified as remote storage + LanceCatalog s3Catalog = new LanceCatalog("test", "default", "s3://bucket/path"); + assertThat(s3Catalog.isRemoteStorage()).isTrue(); + + // S3A path + LanceCatalog s3aCatalog = new LanceCatalog("test", "default", "s3a://bucket/path"); + assertThat(s3aCatalog.isRemoteStorage()).isTrue(); + + // GCS path + LanceCatalog gcsCatalog = new LanceCatalog("test", "default", "gs://bucket/path"); + assertThat(gcsCatalog.isRemoteStorage()).isTrue(); + + // Azure path + LanceCatalog azCatalog = new LanceCatalog("test", "default", "az://container/path"); + assertThat(azCatalog.isRemoteStorage()).isTrue(); + + // Local path should be identified as local storage + LanceCatalog localCatalog = new LanceCatalog("test", "default", warehousePath); + assertThat(localCatalog.isRemoteStorage()).isFalse(); + } + + 
@Test + @DisplayName("Test LanceCatalog construction with storage options") + void testLanceCatalogWithStorageOptions() { + Map<String, String> storageOptions = new HashMap<>(); + storageOptions.put("aws_access_key_id", "test-key"); + storageOptions.put("aws_secret_access_key", "test-secret"); + storageOptions.put("aws_region", "us-east-1"); + + LanceCatalog catalog = + new LanceCatalog("test_catalog", "default", "s3://bucket/warehouse", storageOptions); + + assertThat(catalog.getStorageOptions()).containsEntry("aws_access_key_id", "test-key"); + assertThat(catalog.getStorageOptions()).containsEntry("aws_secret_access_key", "test-secret"); + assertThat(catalog.getStorageOptions()).containsEntry("aws_region", "us-east-1"); + } + + @Test + @DisplayName("Test vector search UDF configuration") + void testVectorSearchFunctionConfiguration() { + LanceVectorSearchFunction function = new LanceVectorSearchFunction(); + assertThat(function).isNotNull(); + } } From 46ca1fa75b3196830f2cd1746a3cc2b8d32ab6de Mon Sep 17 00:00:00 2001 From: Vova Kolmakov Date: Wed, 22 Apr 2026 15:44:09 +0700 Subject: [PATCH 3/3] ci: add Makefile and GitHub Actions workflow for PR checks Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/flink.yml | 48 +++++++++++++++++++++++++++++++++++++ Makefile | 44 ++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+) create mode 100644 .github/workflows/flink.yml create mode 100644 Makefile diff --git a/.github/workflows/flink.yml b/.github/workflows/flink.yml new file mode 100644 index 0000000..51e0465 --- /dev/null +++ b/.github/workflows/flink.yml @@ -0,0 +1,48 @@ +name: Flink CI + +on: + push: + branches: [main] + paths-ignore: + - '**.md' + - 'docs/**' + pull_request: + types: [opened, synchronize, ready_for_review, reopened] + paths-ignore: + - '**.md' + - 'docs/**' + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.event_name == 'pull_request' }} + +jobs: + lint: + runs-on: ubuntu-24.04 + timeout-minutes: 10 + steps: + - uses: actions/checkout@v4 + - name: Set up JDK + uses: actions/setup-java@v4 + with: + distribution: temurin + java-version: '8' + cache: maven + - name: Run lint + run: make lint + + build-and-test: + runs-on: ubuntu-24.04 + timeout-minutes: 30 + steps: + - uses: actions/checkout@v4 + - name: Set up JDK + uses: actions/setup-java@v4 + with: + distribution: temurin + java-version: '8' + cache: maven + - name: Install + run: make install + - name: Test + run: make test diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..3edece7 --- /dev/null +++ b/Makefile @@ -0,0 +1,44 @@ +SHELL := /bin/bash +MAVEN := mvn + +.PHONY: install test build bundle clean lint format \ + install-all test-all build-all help + +help: + @echo "Available targets:" + @echo " install Install artifact to local .m2 (skip tests)" + @echo " test Run unit and integration tests" + @echo " build Run lint + install" + @echo " bundle Build shaded fat-jar (skip tests)" + @echo " clean Remove target/" + @echo " lint Run Spotless and Checkstyle checks" + @echo " format Apply Spotless formatting in place" + @echo " install-all Alias of install (reserved for future multi-module matrix)" + @echo " test-all Alias of test (reserved for future multi-module matrix)" + @echo " build-all Alias of build (reserved for future multi-module matrix)" + +install: + $(MAVEN) install -DskipTests + +test: + $(MAVEN) test + +build: + $(MAKE) lint + $(MAKE) install + +bundle: + $(MAVEN) package -DskipTests + +clean: + $(MAVEN) clean + +lint: + 
$(MAVEN) spotless:check checkstyle:check + +format: + $(MAVEN) spotless:apply + +install-all: install +test-all: test +build-all: build