- * Executes aggregate computation at the data source, supporting COUNT, SUM, AVG, MIN, MAX and other aggregate functions.
+ * Executes aggregate computation at the data source, supporting COUNT, SUM, AVG, MIN, MAX and
+ * other aggregate functions.
*
* <p>Usage example:
+ *
* <pre>{@code
* AggregateInfo aggInfo = AggregateInfo.builder()
* .addCountStar("cnt")
@@ -63,263 +59,249 @@
*/
public class LanceAggregateSource extends RichParallelSourceFunction<RowData> {
- private static final long serialVersionUID = 1L;
- private static final Logger LOG = LoggerFactory.getLogger(LanceAggregateSource.class);
-
- private final LanceOptions options;
- private final RowType sourceRowType;
- private final AggregateInfo aggregateInfo;
- private final String[] selectedColumns;
-
- private transient volatile boolean running;
- private transient BufferAllocator allocator;
- private transient Dataset dataset;
- private transient RowDataConverter converter;
- private transient AggregateExecutor aggregateExecutor;
-
- /**
- * Create LanceAggregateSource
- *
- * @param options Lance configuration options
- * @param sourceRowType RowType of source table
- * @param aggregateInfo Aggregate information
- */
- public LanceAggregateSource(LanceOptions options, RowType sourceRowType, AggregateInfo aggregateInfo) {
- this.options = options;
- this.sourceRowType = sourceRowType;
- this.aggregateInfo = aggregateInfo;
-
- // Calculate columns to read
- List<String> requiredColumns = aggregateInfo.getRequiredColumns();
- this.selectedColumns = requiredColumns.isEmpty() ? null : requiredColumns.toArray(new String[0]);
+ private static final long serialVersionUID = 1L;
+ private static final Logger LOG = LoggerFactory.getLogger(LanceAggregateSource.class);
+
+ private final LanceOptions options;
+ private final RowType sourceRowType;
+ private final AggregateInfo aggregateInfo;
+ private final String[] selectedColumns;
+
+ private transient volatile boolean running;
+ private transient BufferAllocator allocator;
+ private transient Dataset dataset;
+ private transient RowDataConverter converter;
+ private transient AggregateExecutor aggregateExecutor;
+
+ /**
+ * Create LanceAggregateSource
+ *
+ * @param options Lance configuration options
+ * @param sourceRowType RowType of source table
+ * @param aggregateInfo Aggregate information
+ */
+ public LanceAggregateSource(
+ LanceOptions options, RowType sourceRowType, AggregateInfo aggregateInfo) {
+ this.options = options;
+ this.sourceRowType = sourceRowType;
+ this.aggregateInfo = aggregateInfo;
+
+ // Calculate columns to read
+ List<String> requiredColumns = aggregateInfo.getRequiredColumns();
+ this.selectedColumns =
+ requiredColumns.isEmpty() ? null : requiredColumns.toArray(new String[0]);
+ }
+
+ @Override
+ public void open(Configuration parameters) throws Exception {
+ super.open(parameters);
+
+ LOG.info("Opening Lance aggregate data source: {}", options.getPath());
+ LOG.info("Aggregate info: {}", aggregateInfo);
+
+ this.running = true;
+ this.allocator = new RootAllocator(Long.MAX_VALUE);
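+ // No allocation limit is imposed; all Arrow memory is released when the allocator is closed in close().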
+
+ // Open Lance dataset
+ String datasetPath = options.getPath();
+ if (datasetPath == null || datasetPath.isEmpty()) {
+ throw new IllegalArgumentException("Lance dataset path cannot be empty");
}
- @Override
- public void open(Configuration parameters) throws Exception {
- super.open(parameters);
+ try {
+ this.dataset = Dataset.open(datasetPath, allocator);
+ } catch (Exception e) {
+ throw new IOException("Failed to open Lance dataset: " + datasetPath, e);
+ }
- LOG.info("Opening Lance aggregate data source: {}", options.getPath());
- LOG.info("Aggregate info: {}", aggregateInfo);
+ // Initialize RowDataConverter (using source table Schema)
+ RowType actualRowType = this.sourceRowType;
+ if (actualRowType == null) {
+ Schema arrowSchema = dataset.getSchema();
+ actualRowType = LanceTypeConverter.toFlinkRowType(arrowSchema);
+ }
+ this.converter = new RowDataConverter(actualRowType);
- this.running = true;
- this.allocator = new RootAllocator(Long.MAX_VALUE);
+ // Initialize aggregate executor
+ this.aggregateExecutor = new AggregateExecutor(aggregateInfo, actualRowType);
+ this.aggregateExecutor.init();
- // Open Lance dataset
- String datasetPath = options.getPath();
- if (datasetPath == null || datasetPath.isEmpty()) {
- throw new IllegalArgumentException("Lance dataset path cannot be empty");
- }
+ LOG.info("Lance aggregate data source opened");
+ }
- try {
- this.dataset = Dataset.open(datasetPath, allocator);
- } catch (Exception e) {
- throw new IOException("Failed to open Lance dataset: " + datasetPath, e);
- }
-
- // Initialize RowDataConverter (using source table Schema)
- RowType actualRowType = this.sourceRowType;
- if (actualRowType == null) {
- Schema arrowSchema = dataset.getSchema();
- actualRowType = LanceTypeConverter.toFlinkRowType(arrowSchema);
- }
- this.converter = new RowDataConverter(actualRowType);
+ @Override
+ public void run(SourceContext<RowData> ctx) throws Exception {
+ LOG.info("Starting aggregate read from Lance dataset: {}", options.getPath());
- // Initialize aggregate executor
- this.aggregateExecutor = new AggregateExecutor(aggregateInfo, actualRowType);
- this.aggregateExecutor.init();
+ int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
+ int numSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
- LOG.info("Lance aggregate data source opened");
+ // Aggregate operation only executes on subtask 0 to avoid duplicate aggregation
+ if (subtaskIndex != 0) {
+ LOG.info("Subtask {} skipped (only subtask 0 executes in aggregate mode)", subtaskIndex);
+ return;
}
- @Override
- public void run(SourceContext<RowData> ctx) throws Exception {
- LOG.info("Starting aggregate read from Lance dataset: {}", options.getPath());
-
- int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
- int numSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
+ String filter = options.getReadFilter();
- // Aggregate operation only executes on subtask 0 to avoid duplicate aggregation
- if (subtaskIndex != 0) {
- LOG.info("Subtask {} skipped (only subtask 0 executes in aggregate mode)", subtaskIndex);
- return;
- }
+ // Read all data and perform aggregation
+ if (filter != null && !filter.isEmpty()) {
+ readAndAggregateWithFilter(ctx);
+ } else {
+ readAndAggregateAll(ctx);
+ }
- String filter = options.getReadFilter();
+ LOG.info("Lance aggregate data source read completed");
+ }
- // Read all data and perform aggregation
- if (filter != null && !filter.isEmpty()) {
- readAndAggregateWithFilter(ctx);
- } else {
- readAndAggregateAll(ctx);
- }
+ /** Aggregate read with filter condition */
+ private void readAndAggregateWithFilter(SourceContext<RowData> ctx) throws Exception {
+ ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder();
+ scanOptionsBuilder.batchSize(options.getReadBatchSize());
- LOG.info("Lance aggregate data source read completed");
+ if (selectedColumns != null && selectedColumns.length > 0) {
+ scanOptionsBuilder.columns(Arrays.asList(selectedColumns));
}
- /**
- * Aggregate read with filter condition
- */
- private void readAndAggregateWithFilter(SourceContext<RowData> ctx) throws Exception {
- ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder();
- scanOptionsBuilder.batchSize(options.getReadBatchSize());
-
- if (selectedColumns != null && selectedColumns.length > 0) {
- scanOptionsBuilder.columns(Arrays.asList(selectedColumns));
- }
-
- String filter = options.getReadFilter();
- if (filter != null && !filter.isEmpty()) {
- LOG.info("Applying filter condition: {}", filter);
- scanOptionsBuilder.filter(filter);
- }
+ String filter = options.getReadFilter();
+ if (filter != null && !filter.isEmpty()) {
+ LOG.info("Applying filter condition: {}", filter);
+ scanOptionsBuilder.filter(filter);
+ }
- ScanOptions scanOptions = scanOptionsBuilder.build();
+ ScanOptions scanOptions = scanOptionsBuilder.build();
- // Phase 1: Read data and accumulate aggregation
- try (LanceScanner scanner = dataset.newScan(scanOptions)) {
- try (ArrowReader reader = scanner.scanBatches()) {
- while (reader.loadNextBatch() && running) {
- VectorSchemaRoot root = reader.getVectorSchemaRoot();
- List<RowData> rows = converter.toRowDataList(root);
+ // Phase 1: Read data and accumulate aggregation
+ try (LanceScanner scanner = dataset.newScan(scanOptions)) {
+ try (ArrowReader reader = scanner.scanBatches()) {
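+ // The running flag is checked between batches so cancellation takes effect promptly.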
+ while (reader.loadNextBatch() && running) {
+ VectorSchemaRoot root = reader.getVectorSchemaRoot();
+ List<RowData> rows = converter.toRowDataList(root);
- for (RowData row : rows) {
- aggregateExecutor.accumulate(row);
- }
- }
- }
+ for (RowData row : rows) {
+ aggregateExecutor.accumulate(row);
+ }
}
-
- // Phase 2: Output aggregate results
- outputAggregateResults(ctx);
+ }
}
- /**
- * Read all data and aggregate (without filter condition)
- */
- private void readAndAggregateAll(SourceContext<RowData> ctx) throws Exception {
- List<Fragment> fragments = dataset.getFragments();
- LOG.info("Dataset has {} Fragments", fragments.size());
-
- // Phase 1: Read all Fragments and accumulate aggregation
- for (Fragment fragment : fragments) {
- if (!running) {
- break;
- }
- readAndAggregateFragment(fragment);
- }
-
- // Phase 2: Output aggregate results
- outputAggregateResults(ctx);
+ // Phase 2: Output aggregate results
+ outputAggregateResults(ctx);
+ }
+
+ /** Read all data and aggregate (without filter condition) */
+ private void readAndAggregateAll(SourceContext<RowData> ctx) throws Exception {
+ List<Fragment> fragments = dataset.getFragments();
+ LOG.info("Dataset has {} Fragments", fragments.size());
+
+ // Phase 1: Read all Fragments and accumulate aggregation
+ for (Fragment fragment : fragments) {
+ if (!running) {
+ break;
+ }
+ readAndAggregateFragment(fragment);
}
- /**
- * Read single Fragment and accumulate aggregation
- */
- private void readAndAggregateFragment(Fragment fragment) throws Exception {
- LOG.debug("Reading Fragment: {}", fragment.getId());
+ // Phase 2: Output aggregate results
+ outputAggregateResults(ctx);
+ }
- ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder();
- scanOptionsBuilder.batchSize(options.getReadBatchSize());
+ /** Read single Fragment and accumulate aggregation */
+ private void readAndAggregateFragment(Fragment fragment) throws Exception {
+ LOG.debug("Reading Fragment: {}", fragment.getId());
- if (selectedColumns != null && selectedColumns.length > 0) {
- scanOptionsBuilder.columns(Arrays.asList(selectedColumns));
- }
+ ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder();
+ scanOptionsBuilder.batchSize(options.getReadBatchSize());
- ScanOptions scanOptions = scanOptionsBuilder.build();
+ if (selectedColumns != null && selectedColumns.length > 0) {
+ scanOptionsBuilder.columns(Arrays.asList(selectedColumns));
+ }
- try (LanceScanner scanner = fragment.newScan(scanOptions)) {
- try (ArrowReader reader = scanner.scanBatches()) {
- while (reader.loadNextBatch() && running) {
- VectorSchemaRoot root = reader.getVectorSchemaRoot();
- List<RowData> rows = converter.toRowDataList(root);
+ ScanOptions scanOptions = scanOptionsBuilder.build();
- for (RowData row : rows) {
- aggregateExecutor.accumulate(row);
- }
- }
- }
- }
- }
+ try (LanceScanner scanner = fragment.newScan(scanOptions)) {
+ try (ArrowReader reader = scanner.scanBatches()) {
+ while (reader.loadNextBatch() && running) {
+ VectorSchemaRoot root = reader.getVectorSchemaRoot();
+ List<RowData> rows = converter.toRowDataList(root);
- /**
- * Output aggregate results
- */
- private void outputAggregateResults(SourceContext<RowData> ctx) {
- List<RowData> results = aggregateExecutor.getResults();
- LOG.info("Aggregation completed, {} result rows", results.size());
-
- synchronized (ctx.getCheckpointLock()) {
- for (RowData result : results) {
- ctx.collect(result);
- }
+ for (RowData row : rows) {
+ aggregateExecutor.accumulate(row);
+ }
}
+ }
}
+ }
+
+ /** Output aggregate results */
+ private void outputAggregateResults(SourceContext<RowData> ctx) {
+ List<RowData> results = aggregateExecutor.getResults();
+ LOG.info("Aggregation completed, {} result rows", results.size());
- @Override
- public void cancel() {
- LOG.info("Cancelling Lance aggregate data source");
- this.running = false;
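+ // Emit under the checkpoint lock so output and checkpoints do not interleave.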
+ synchronized (ctx.getCheckpointLock()) {
+ for (RowData result : results) {
+ ctx.collect(result);
+ }
}
+ }
- @Override
- public void close() throws Exception {
- LOG.info("Closing Lance aggregate data source");
+ @Override
+ public void cancel() {
+ LOG.info("Cancelling Lance aggregate data source");
+ this.running = false;
+ }
- this.running = false;
+ @Override
+ public void close() throws Exception {
+ LOG.info("Closing Lance aggregate data source");
- if (aggregateExecutor != null) {
- aggregateExecutor.reset();
- }
+ this.running = false;
- if (dataset != null) {
- try {
- dataset.close();
- } catch (Exception e) {
- LOG.warn("Error closing Lance dataset", e);
- }
- dataset = null;
- }
-
- if (allocator != null) {
- try {
- allocator.close();
- } catch (Exception e) {
- LOG.warn("Error closing memory allocator", e);
- }
- allocator = null;
- }
-
- super.close();
+ if (aggregateExecutor != null) {
+ aggregateExecutor.reset();
}
- /**
- * Get aggregate information
- */
- public AggregateInfo getAggregateInfo() {
- return aggregateInfo;
+ if (dataset != null) {
+ try {
+ dataset.close();
+ } catch (Exception e) {
+ LOG.warn("Error closing Lance dataset", e);
+ }
+ dataset = null;
}
- /**
- * Get configuration options
- */
- public LanceOptions getOptions() {
- return options;
+ if (allocator != null) {
+ try {
+ allocator.close();
+ } catch (Exception e) {
+ LOG.warn("Error closing memory allocator", e);
+ }
+ allocator = null;
}
- /**
- * Get source table RowType
- */
- public RowType getSourceRowType() {
- return sourceRowType;
- }
+ super.close();
+ }
- /**
- * Get aggregate result RowType
- */
- public RowType getResultRowType() {
- if (aggregateExecutor != null) {
- return aggregateExecutor.buildResultRowType();
- }
- return null;
+ /** Get aggregate information */
+ public AggregateInfo getAggregateInfo() {
+ return aggregateInfo;
+ }
+
+ /** Get configuration options */
+ public LanceOptions getOptions() {
+ return options;
+ }
+
+ /** Get source table RowType */
+ public RowType getSourceRowType() {
+ return sourceRowType;
+ }
+
+ /** Get aggregate result RowType */
+ public RowType getResultRowType() {
+ if (aggregateExecutor != null) {
+ return aggregateExecutor.buildResultRowType();
}
+ return null;
+ }
}
diff --git a/src/main/java/org/apache/flink/connector/lance/LanceIndexBuilder.java b/src/main/java/org/apache/flink/connector/lance/LanceIndexBuilder.java
index 3f93897..59b780e 100644
--- a/src/main/java/org/apache/flink/connector/lance/LanceIndexBuilder.java
+++ b/src/main/java/org/apache/flink/connector/lance/LanceIndexBuilder.java
@@ -1,11 +1,7 @@
/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
@@ -15,7 +11,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.flink.connector.lance;
import org.apache.flink.connector.lance.config.LanceOptions;
@@ -41,10 +36,11 @@
/**
* Lance vector index builder.
- *
+ *
* Supports building IVF_PQ, IVF_HNSW_PQ, and IVF_FLAT vector indices.
- *
+ *
* <p>Usage example:
+ *
* <pre>{@code
* LanceIndexBuilder builder = LanceIndexBuilder.builder()
* .datasetPath("/path/to/dataset")
@@ -53,384 +49,390 @@
* .numPartitions(256)
* .numSubVectors(16)
* .build();
- *
+ *
* IndexBuildResult result = builder.buildIndex();
* }</pre>
*/
public class LanceIndexBuilder implements Closeable, Serializable {
- private static final long serialVersionUID = 1L;
- private static final Logger LOG = LoggerFactory.getLogger(LanceIndexBuilder.class);
+ private static final long serialVersionUID = 1L;
+ private static final Logger LOG = LoggerFactory.getLogger(LanceIndexBuilder.class);
+
+ private final String datasetPath;
+ private final String columnName;
+ private final LanceOptions.IndexType indexType;
+ private final LanceOptions.MetricType metricType;
+ private final int numPartitions;
+ private final Integer numSubVectors;
+ private final int numBits;
+ private final int maxLevel;
+ private final int m;
+ private final int efConstruction;
+ private final boolean replace;
+
+ private transient BufferAllocator allocator;
+ private transient Dataset dataset;
+
+ private LanceIndexBuilder(Builder builder) {
+ this.datasetPath = builder.datasetPath;
+ this.columnName = builder.columnName;
+ this.indexType = builder.indexType;
+ this.metricType = builder.metricType;
+ this.numPartitions = builder.numPartitions;
+ this.numSubVectors = builder.numSubVectors;
+ this.numBits = builder.numBits;
+ this.maxLevel = builder.maxLevel;
+ this.m = builder.m;
+ this.efConstruction = builder.efConstruction;
+ this.replace = builder.replace;
+ }
+
+ /**
+ * Build vector index
+ *
+ * @return Index build result
+ */
+ public IndexBuildResult buildIndex() throws IOException {
+ LOG.info(
+ "Starting to build vector index, type: {}, column: {}, dataset: {}",
+ indexType,
+ columnName,
+ datasetPath);
+
+ long startTime = System.currentTimeMillis();
+
+ try {
+ // Initialize resources
+ this.allocator = new RootAllocator(Long.MAX_VALUE);
+ this.dataset = Dataset.open(datasetPath, allocator);
+
+ // Validate column exists
+ validateColumn();
+
+ // Get distance metric type
+ DistanceType distanceType = toDistanceType(metricType);
+
+ // Build IVF parameters
+ IvfBuildParams ivfParams =
+ new IvfBuildParams.Builder().setNumPartitions(numPartitions).build();
+
+ // Build index based on index type
+ IndexType lanceIndexType;
+ IndexParams indexParams;
+
+ switch (indexType) {
+ case IVF_PQ:
+ lanceIndexType = IndexType.IVF_PQ;
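+ // Fall back to 16 PQ sub-vectors when none are configured.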
+ PQBuildParams pqParams =
+ new PQBuildParams.Builder()
+ .setNumSubVectors(numSubVectors != null ? numSubVectors : 16)
+ .setNumBits(numBits)
+ .build();
+ VectorIndexParams ivfPqParams =
+ VectorIndexParams.withIvfPqParams(distanceType, ivfParams, pqParams);
+ indexParams =
+ new IndexParams.Builder()
+ .setDistanceType(distanceType)
+ .setVectorIndexParams(ivfPqParams)
+ .build();
+ break;
+
+ case IVF_HNSW:
+ lanceIndexType = IndexType.IVF_HNSW_PQ;
+ HnswBuildParams hnswParams =
+ new HnswBuildParams.Builder()
+ .setMaxLevel((short) maxLevel)
+ .setM(m)
+ .setEfConstruction(efConstruction)
+ .build();
+ PQBuildParams hnswPqParams =
+ new PQBuildParams.Builder()
+ .setNumSubVectors(numSubVectors != null ? numSubVectors : 16)
+ .setNumBits(numBits)
+ .build();
+ VectorIndexParams ivfHnswParams =
+ VectorIndexParams.withIvfHnswPqParams(
+ distanceType, ivfParams, hnswParams, hnswPqParams);
+ indexParams =
+ new IndexParams.Builder()
+ .setDistanceType(distanceType)
+ .setVectorIndexParams(ivfHnswParams)
+ .build();
+ break;
+
+ case IVF_FLAT:
+ lanceIndexType = IndexType.IVF_FLAT;
+ VectorIndexParams ivfFlatParams = VectorIndexParams.ivfFlat(numPartitions, distanceType);
+ indexParams =
+ new IndexParams.Builder()
+ .setDistanceType(distanceType)
+ .setVectorIndexParams(ivfFlatParams)
+ .build();
+ break;
+
+ default:
+ throw new IllegalArgumentException("Unsupported index type: " + indexType);
+ }
+
+ // Create index
+ dataset.createIndex(
+ Collections.singletonList(columnName),
+ lanceIndexType,
+ Optional.empty(), // Index name, use default
+ indexParams,
+ replace);
+
+ long endTime = System.currentTimeMillis();
+ long duration = endTime - startTime;
+
+ LOG.info("Vector index build completed, duration: {} ms", duration);
+
+ return new IndexBuildResult(true, indexType, columnName, datasetPath, duration, null);
+ } catch (Exception e) {
+ LOG.error("Failed to build vector index", e);
+ return new IndexBuildResult(
+ false,
+ indexType,
+ columnName,
+ datasetPath,
+ System.currentTimeMillis() - startTime,
+ e.getMessage());
+ }
+ }
- private final String datasetPath;
- private final String columnName;
- private final LanceOptions.IndexType indexType;
- private final LanceOptions.MetricType metricType;
- private final int numPartitions;
- private final Integer numSubVectors;
- private final int numBits;
- private final int maxLevel;
- private final int m;
- private final int efConstruction;
- private final boolean replace;
-
- private transient BufferAllocator allocator;
- private transient Dataset dataset;
-
- private LanceIndexBuilder(Builder builder) {
- this.datasetPath = builder.datasetPath;
- this.columnName = builder.columnName;
- this.indexType = builder.indexType;
- this.metricType = builder.metricType;
- this.numPartitions = builder.numPartitions;
- this.numSubVectors = builder.numSubVectors;
- this.numBits = builder.numBits;
- this.maxLevel = builder.maxLevel;
- this.m = builder.m;
- this.efConstruction = builder.efConstruction;
- this.replace = builder.replace;
+ /** Validate vector column exists */
+ private void validateColumn() throws IOException {
+ // Check if column exists in Schema
+ boolean columnExists =
+ dataset.getSchema().getFields().stream()
+ .anyMatch(field -> field.getName().equals(columnName));
+
+ if (!columnExists) {
+ throw new IOException("Vector column does not exist: " + columnName);
+ }
+ }
+
+ /** Convert distance metric type */
+ private DistanceType toDistanceType(LanceOptions.MetricType metricType) {
+ switch (metricType) {
+ case L2:
+ return DistanceType.L2;
+ case COSINE:
+ return DistanceType.Cosine;
+ case DOT:
+ return DistanceType.Dot;
+ default:
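+ // Unrecognized metric types fall back to L2 rather than failing.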
+ return DistanceType.L2;
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ if (dataset != null) {
+ try {
+ dataset.close();
+ } catch (Exception e) {
+ LOG.warn("Failed to close dataset", e);
+ }
+ dataset = null;
}
- /**
- * Build vector index
- *
- * @return Index build result
- */
- public IndexBuildResult buildIndex() throws IOException {
- LOG.info("Starting to build vector index, type: {}, column: {}, dataset: {}",
- indexType, columnName, datasetPath);
-
- long startTime = System.currentTimeMillis();
-
- try {
- // Initialize resources
- this.allocator = new RootAllocator(Long.MAX_VALUE);
- this.dataset = Dataset.open(datasetPath, allocator);
-
- // Validate column exists
- validateColumn();
-
- // Get distance metric type
- DistanceType distanceType = toDistanceType(metricType);
-
- // Build IVF parameters
- IvfBuildParams ivfParams = new IvfBuildParams.Builder()
- .setNumPartitions(numPartitions)
- .build();
-
- // Build index based on index type
- IndexType lanceIndexType;
- IndexParams indexParams;
-
- switch (indexType) {
- case IVF_PQ:
- lanceIndexType = IndexType.IVF_PQ;
- PQBuildParams pqParams = new PQBuildParams.Builder()
- .setNumSubVectors(numSubVectors != null ? numSubVectors : 16)
- .setNumBits(numBits)
- .build();
- VectorIndexParams ivfPqParams = VectorIndexParams.withIvfPqParams(
- distanceType, ivfParams, pqParams);
- indexParams = new IndexParams.Builder()
- .setDistanceType(distanceType)
- .setVectorIndexParams(ivfPqParams)
- .build();
- break;
-
- case IVF_HNSW:
- lanceIndexType = IndexType.IVF_HNSW_PQ;
- HnswBuildParams hnswParams = new HnswBuildParams.Builder()
- .setMaxLevel((short) maxLevel)
- .setM(m)
- .setEfConstruction(efConstruction)
- .build();
- PQBuildParams hnswPqParams = new PQBuildParams.Builder()
- .setNumSubVectors(numSubVectors != null ? numSubVectors : 16)
- .setNumBits(numBits)
- .build();
- VectorIndexParams ivfHnswParams = VectorIndexParams.withIvfHnswPqParams(
- distanceType, ivfParams, hnswParams, hnswPqParams);
- indexParams = new IndexParams.Builder()
- .setDistanceType(distanceType)
- .setVectorIndexParams(ivfHnswParams)
- .build();
- break;
-
- case IVF_FLAT:
- lanceIndexType = IndexType.IVF_FLAT;
- VectorIndexParams ivfFlatParams = VectorIndexParams.ivfFlat(numPartitions, distanceType);
- indexParams = new IndexParams.Builder()
- .setDistanceType(distanceType)
- .setVectorIndexParams(ivfFlatParams)
- .build();
- break;
-
- default:
- throw new IllegalArgumentException("Unsupported index type: " + indexType);
- }
-
- // Create index
- dataset.createIndex(
- Collections.singletonList(columnName),
- lanceIndexType,
- Optional.empty(), // Index name, use default
- indexParams,
- replace
- );
-
- long endTime = System.currentTimeMillis();
- long duration = endTime - startTime;
-
- LOG.info("Vector index build completed, duration: {} ms", duration);
-
- return new IndexBuildResult(
- true,
- indexType,
- columnName,
- datasetPath,
- duration,
- null
- );
- } catch (Exception e) {
- LOG.error("Failed to build vector index", e);
- return new IndexBuildResult(
- false,
- indexType,
- columnName,
- datasetPath,
- System.currentTimeMillis() - startTime,
- e.getMessage()
- );
- }
+ if (allocator != null) {
+ try {
+ allocator.close();
+ } catch (Exception e) {
+ LOG.warn("Failed to close allocator", e);
+ }
+ allocator = null;
+ }
+ }
+
+ /** Create builder */
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ /** Create index builder from LanceOptions */
+ public static LanceIndexBuilder fromOptions(LanceOptions options) {
+ return builder()
+ .datasetPath(options.getPath())
+ .columnName(options.getIndexColumn())
+ .indexType(options.getIndexType())
+ .metricType(options.getVectorMetric())
+ .numPartitions(options.getIndexNumPartitions())
+ .numSubVectors(options.getIndexNumSubVectors())
+ .numBits(options.getIndexNumBits())
+ .maxLevel(options.getIndexMaxLevel())
+ .m(options.getIndexM())
+ .efConstruction(options.getIndexEfConstruction())
+ .build();
+ }
+
+ /** Builder */
+ public static class Builder {
+ private String datasetPath;
+ private String columnName;
+ private LanceOptions.IndexType indexType = LanceOptions.IndexType.IVF_PQ;
+ private LanceOptions.MetricType metricType = LanceOptions.MetricType.L2;
+ private int numPartitions = 256;
+ private Integer numSubVectors;
+ private int numBits = 8;
+ private int maxLevel = 7;
+ private int m = 16;
+ private int efConstruction = 100;
+ private boolean replace = false;
+
+ public Builder datasetPath(String datasetPath) {
+ this.datasetPath = datasetPath;
+ return this;
}
- /**
- * Validate vector column exists
- */
- private void validateColumn() throws IOException {
- // Check if column exists in Schema
- boolean columnExists = dataset.getSchema().getFields().stream()
- .anyMatch(field -> field.getName().equals(columnName));
-
- if (!columnExists) {
- throw new IOException("Vector column does not exist: " + columnName);
- }
+ public Builder columnName(String columnName) {
+ this.columnName = columnName;
+ return this;
}
- /**
- * Convert distance metric type
- */
- private DistanceType toDistanceType(LanceOptions.MetricType metricType) {
- switch (metricType) {
- case L2:
- return DistanceType.L2;
- case COSINE:
- return DistanceType.Cosine;
- case DOT:
- return DistanceType.Dot;
- default:
- return DistanceType.L2;
- }
+ public Builder indexType(LanceOptions.IndexType indexType) {
+ this.indexType = indexType;
+ return this;
}
- @Override
- public void close() throws IOException {
- if (dataset != null) {
- try {
- dataset.close();
- } catch (Exception e) {
- LOG.warn("Failed to close dataset", e);
- }
- dataset = null;
- }
-
- if (allocator != null) {
- try {
- allocator.close();
- } catch (Exception e) {
- LOG.warn("Failed to close allocator", e);
- }
- allocator = null;
- }
+ public Builder metricType(LanceOptions.MetricType metricType) {
+ this.metricType = metricType;
+ return this;
+ }
+
+ public Builder numPartitions(int numPartitions) {
+ this.numPartitions = numPartitions;
+ return this;
+ }
+
+ public Builder numSubVectors(Integer numSubVectors) {
+ this.numSubVectors = numSubVectors;
+ return this;
+ }
+
+ public Builder numBits(int numBits) {
+ this.numBits = numBits;
+ return this;
+ }
+
+ public Builder maxLevel(int maxLevel) {
+ this.maxLevel = maxLevel;
+ return this;
+ }
+
+ // checkstyle.off: MethodName
+ public Builder m(int m) {
+ this.m = m;
+ return this;
+ }
+
+ // checkstyle.on: MethodName
+
+ public Builder efConstruction(int efConstruction) {
+ this.efConstruction = efConstruction;
+ return this;
+ }
+
+ public Builder replace(boolean replace) {
+ this.replace = replace;
+ return this;
+ }
+
+ public LanceIndexBuilder build() {
+ validate();
+ return new LanceIndexBuilder(this);
+ }
+
+ private void validate() {
+ if (datasetPath == null || datasetPath.isEmpty()) {
+ throw new IllegalArgumentException("Dataset path cannot be empty");
+ }
+ if (columnName == null || columnName.isEmpty()) {
+ throw new IllegalArgumentException("Column name cannot be empty");
+ }
+ if (numPartitions <= 0) {
+ throw new IllegalArgumentException("Number of partitions must be greater than 0");
+ }
+ if (numSubVectors != null && numSubVectors <= 0) {
+ throw new IllegalArgumentException("Number of sub-vectors must be greater than 0");
+ }
+ if (numBits <= 0 || numBits > 16) {
+ throw new IllegalArgumentException("Quantization bits must be between 1 and 16");
+ }
+ }
+ }
+
+ /** Index build result */
+ public static class IndexBuildResult implements Serializable {
+ private static final long serialVersionUID = 1L;
+
+ private final boolean success;
+ private final LanceOptions.IndexType indexType;
+ private final String columnName;
+ private final String datasetPath;
+ private final long durationMillis;
+ private final String errorMessage;
+
+ public IndexBuildResult(
+ boolean success,
+ LanceOptions.IndexType indexType,
+ String columnName,
+ String datasetPath,
+ long durationMillis,
+ String errorMessage) {
+ this.success = success;
+ this.indexType = indexType;
+ this.columnName = columnName;
+ this.datasetPath = datasetPath;
+ this.durationMillis = durationMillis;
+ this.errorMessage = errorMessage;
+ }
+
+ public boolean isSuccess() {
+ return success;
+ }
+
+ public LanceOptions.IndexType getIndexType() {
+ return indexType;
}
- /**
- * Create builder
- */
- public static Builder builder() {
- return new Builder();
+ public String getColumnName() {
+ return columnName;
}
- /**
- * Create index builder from LanceOptions
- */
- public static LanceIndexBuilder fromOptions(LanceOptions options) {
- return builder()
- .datasetPath(options.getPath())
- .columnName(options.getIndexColumn())
- .indexType(options.getIndexType())
- .metricType(options.getVectorMetric())
- .numPartitions(options.getIndexNumPartitions())
- .numSubVectors(options.getIndexNumSubVectors())
- .numBits(options.getIndexNumBits())
- .maxLevel(options.getIndexMaxLevel())
- .m(options.getIndexM())
- .efConstruction(options.getIndexEfConstruction())
- .build();
+ public String getDatasetPath() {
+ return datasetPath;
}
- /**
- * Builder
- */
- public static class Builder {
- private String datasetPath;
- private String columnName;
- private LanceOptions.IndexType indexType = LanceOptions.IndexType.IVF_PQ;
- private LanceOptions.MetricType metricType = LanceOptions.MetricType.L2;
- private int numPartitions = 256;
- private Integer numSubVectors;
- private int numBits = 8;
- private int maxLevel = 7;
- private int m = 16;
- private int efConstruction = 100;
- private boolean replace = false;
-
- public Builder datasetPath(String datasetPath) {
- this.datasetPath = datasetPath;
- return this;
- }
-
- public Builder columnName(String columnName) {
- this.columnName = columnName;
- return this;
- }
-
- public Builder indexType(LanceOptions.IndexType indexType) {
- this.indexType = indexType;
- return this;
- }
-
- public Builder metricType(LanceOptions.MetricType metricType) {
- this.metricType = metricType;
- return this;
- }
-
- public Builder numPartitions(int numPartitions) {
- this.numPartitions = numPartitions;
- return this;
- }
-
- public Builder numSubVectors(Integer numSubVectors) {
- this.numSubVectors = numSubVectors;
- return this;
- }
-
- public Builder numBits(int numBits) {
- this.numBits = numBits;
- return this;
- }
-
- public Builder maxLevel(int maxLevel) {
- this.maxLevel = maxLevel;
- return this;
- }
-
- public Builder m(int m) {
- this.m = m;
- return this;
- }
-
- public Builder efConstruction(int efConstruction) {
- this.efConstruction = efConstruction;
- return this;
- }
-
- public Builder replace(boolean replace) {
- this.replace = replace;
- return this;
- }
-
- public LanceIndexBuilder build() {
- validate();
- return new LanceIndexBuilder(this);
- }
-
- private void validate() {
- if (datasetPath == null || datasetPath.isEmpty()) {
- throw new IllegalArgumentException("Dataset path cannot be empty");
- }
- if (columnName == null || columnName.isEmpty()) {
- throw new IllegalArgumentException("Column name cannot be empty");
- }
- if (numPartitions <= 0) {
- throw new IllegalArgumentException("Number of partitions must be greater than 0");
- }
- if (numSubVectors != null && numSubVectors <= 0) {
- throw new IllegalArgumentException("Number of sub-vectors must be greater than 0");
- }
- if (numBits <= 0 || numBits > 16) {
- throw new IllegalArgumentException("Quantization bits must be between 1 and 16");
- }
- }
+ public long getDurationMillis() {
+ return durationMillis;
}
- /**
- * Index build result
- */
- public static class IndexBuildResult implements Serializable {
- private static final long serialVersionUID = 1L;
-
- private final boolean success;
- private final LanceOptions.IndexType indexType;
- private final String columnName;
- private final String datasetPath;
- private final long durationMillis;
- private final String errorMessage;
-
- public IndexBuildResult(boolean success, LanceOptions.IndexType indexType, String columnName,
- String datasetPath, long durationMillis, String errorMessage) {
- this.success = success;
- this.indexType = indexType;
- this.columnName = columnName;
- this.datasetPath = datasetPath;
- this.durationMillis = durationMillis;
- this.errorMessage = errorMessage;
- }
-
- public boolean isSuccess() {
- return success;
- }
-
- public LanceOptions.IndexType getIndexType() {
- return indexType;
- }
-
- public String getColumnName() {
- return columnName;
- }
-
- public String getDatasetPath() {
- return datasetPath;
- }
-
- public long getDurationMillis() {
- return durationMillis;
- }
-
- public String getErrorMessage() {
- return errorMessage;
- }
-
- @Override
- public String toString() {
- return "IndexBuildResult{" +
- "success=" + success +
- ", indexType=" + indexType +
- ", columnName='" + columnName + '\'' +
- ", datasetPath='" + datasetPath + '\'' +
- ", durationMillis=" + durationMillis +
- ", errorMessage='" + errorMessage + '\'' +
- '}';
- }
+ public String getErrorMessage() {
+ return errorMessage;
+ }
+
+ @Override
+ public String toString() {
+ return "IndexBuildResult{"
+ + "success="
+ + success
+ + ", indexType="
+ + indexType
+ + ", columnName='"
+ + columnName
+ + '\''
+ + ", datasetPath='"
+ + datasetPath
+ + '\''
+ + ", durationMillis="
+ + durationMillis
+ + ", errorMessage='"
+ + errorMessage
+ + '\''
+ + '}';
}
+ }
}
diff --git a/src/main/java/org/apache/flink/connector/lance/LanceInputFormat.java b/src/main/java/org/apache/flink/connector/lance/LanceInputFormat.java
index 2785a00..be58ce3 100644
--- a/src/main/java/org/apache/flink/connector/lance/LanceInputFormat.java
+++ b/src/main/java/org/apache/flink/connector/lance/LanceInputFormat.java
@@ -1,11 +1,7 @@
/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
@@ -15,7 +11,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.flink.connector.lance;
import org.apache.flink.api.common.io.RichInputFormat;
@@ -41,7 +36,6 @@
import org.slf4j.LoggerFactory;
import java.io.IOException;
-import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
@@ -49,287 +43,279 @@
/**
* Lance InputFormat implementation.
- *
- * Reads data from Lance dataset using InputFormat interface, supports parallel reading with splits.
+ *
+ * <p>Reads data from Lance dataset using InputFormat interface, supports parallel reading with
+ * splits.
*/
public class LanceInputFormat extends RichInputFormat<RowData, LanceSplit> {
- private static final long serialVersionUID = 1L;
- private static final Logger LOG = LoggerFactory.getLogger(LanceInputFormat.class);
-
- private final LanceOptions options;
- private final RowType rowType;
- private final String[] selectedColumns;
-
- private transient BufferAllocator allocator;
- private transient Dataset dataset;
- private transient RowDataConverter converter;
- private transient LanceScanner currentScanner;
- private transient ArrowReader currentReader;
- private transient Iterator<RowData> currentBatchIterator;
- private transient boolean reachedEnd;
-
- /**
- * Create LanceInputFormat
- *
- * @param options Lance configuration options
- * @param rowType Flink RowType
- */
- public LanceInputFormat(LanceOptions options, RowType rowType) {
- this.options = options;
- this.rowType = rowType;
-
- List<String> columns = options.getReadColumns();
- this.selectedColumns = columns != null && !columns.isEmpty()
- ? columns.toArray(new String[0])
- : null;
+ private static final long serialVersionUID = 1L;
+ private static final Logger LOG = LoggerFactory.getLogger(LanceInputFormat.class);
+
+ private final LanceOptions options;
+ private final RowType rowType;
+ private final String[] selectedColumns;
+
+ private transient BufferAllocator allocator;
+ private transient Dataset dataset;
+ private transient RowDataConverter converter;
+ private transient LanceScanner currentScanner;
+ private transient ArrowReader currentReader;
+ private transient Iterator<RowData> currentBatchIterator;
+ private transient boolean reachedEnd;
+
+ /**
+ * Create LanceInputFormat
+ *
+ * @param options Lance configuration options
+ * @param rowType Flink RowType
+ */
+ public LanceInputFormat(LanceOptions options, RowType rowType) {
+ this.options = options;
+ this.rowType = rowType;
+
+ List<String> columns = options.getReadColumns();
+ this.selectedColumns =
+ columns != null && !columns.isEmpty() ? columns.toArray(new String[0]) : null;
+ }
+
+ @Override
+ public void configure(Configuration parameters) {
+ // Configuration already done in constructor
+ }
+
+ @Override
+ public BaseStatistics getStatistics(BaseStatistics cachedStatistics) throws IOException {
+ // Return basic statistics
+ return cachedStatistics;
+ }
+
+ @Override
+ public LanceSplit[] createInputSplits(int minNumSplits) throws IOException {
+ LOG.info("Creating input splits, minimum split count: {}", minNumSplits);
+
+ String datasetPath = options.getPath();
+ if (datasetPath == null || datasetPath.isEmpty()) {
+ throw new IOException("Dataset path cannot be empty");
}
- @Override
- public void configure(Configuration parameters) {
- // Configuration already done in constructor
+ BufferAllocator tempAllocator = new RootAllocator(Long.MAX_VALUE);
+ try {
+ Dataset tempDataset = Dataset.open(datasetPath, tempAllocator);
+ try {
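+ // One input split is created per Lance fragment; the minNumSplits hint is not used.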
+ List<Fragment> fragments = tempDataset.getFragments();
+ LanceSplit[] splits = new LanceSplit[fragments.size()];
+
+ for (int i = 0; i < fragments.size(); i++) {
+ Fragment fragment = fragments.get(i);
+ long rowCount = fragment.countRows();
+ splits[i] = new LanceSplit(i, fragment.getId(), datasetPath, rowCount);
+ }
+
+ LOG.info("Created {} input splits", splits.length);
+ return splits;
+ } finally {
+ tempDataset.close();
+ }
+ } finally {
+ tempAllocator.close();
}
+ }
- @Override
- public BaseStatistics getStatistics(BaseStatistics cachedStatistics) throws IOException {
- // Return basic statistics
- return cachedStatistics;
+ @Override
+ public InputSplitAssigner getInputSplitAssigner(LanceSplit[] inputSplits) {
+ return new LanceSplitAssigner(inputSplits);
+ }
+
+ @Override
+ public void open(LanceSplit split) throws IOException {
+ LOG.info("Opening split: {}", split);
+
+ this.allocator = new RootAllocator(Long.MAX_VALUE);
+ this.reachedEnd = false;
+
+ // Open dataset
+ String datasetPath = split.getDatasetPath();
+ try {
+ this.dataset = Dataset.open(datasetPath, allocator);
+ } catch (Exception e) {
+ throw new IOException("Cannot open dataset: " + datasetPath, e);
}
- @Override
- public LanceSplit[] createInputSplits(int minNumSplits) throws IOException {
- LOG.info("Creating input splits, minimum split count: {}", minNumSplits);
-
- String datasetPath = options.getPath();
- if (datasetPath == null || datasetPath.isEmpty()) {
- throw new IOException("Dataset path cannot be empty");
- }
+ // Initialize converter
+ RowType actualRowType = this.rowType;
+ if (actualRowType == null) {
+ Schema arrowSchema = dataset.getSchema();
+ actualRowType = LanceTypeConverter.toFlinkRowType(arrowSchema);
+ }
+ this.converter = new RowDataConverter(actualRowType);
- BufferAllocator tempAllocator = new RootAllocator(Long.MAX_VALUE);
- try {
- Dataset tempDataset = Dataset.open(datasetPath, tempAllocator);
- try {
- List<Fragment> fragments = tempDataset.getFragments();
- LanceSplit[] splits = new LanceSplit[fragments.size()];
-
- for (int i = 0; i < fragments.size(); i++) {
- Fragment fragment = fragments.get(i);
- long rowCount = fragment.countRows();
- splits[i] = new LanceSplit(i, fragment.getId(), datasetPath, rowCount);
- }
-
- LOG.info("Created {} input splits", splits.length);
- return splits;
- } finally {
- tempDataset.close();
- }
- } finally {
- tempAllocator.close();
- }
+ // Get specified Fragment
+ List<Fragment> fragments = dataset.getFragments();
+ Fragment targetFragment = null;
+ for (Fragment fragment : fragments) {
+ if (fragment.getId() == split.getFragmentId()) {
+ targetFragment = fragment;
+ break;
+ }
}
- @Override
- public InputSplitAssigner getInputSplitAssigner(LanceSplit[] inputSplits) {
- return new LanceSplitAssigner(inputSplits);
+ if (targetFragment == null) {
+ throw new IOException("Cannot find Fragment: " + split.getFragmentId());
}
- @Override
- public void open(LanceSplit split) throws IOException {
- LOG.info("Opening split: {}", split);
-
- this.allocator = new RootAllocator(Long.MAX_VALUE);
- this.reachedEnd = false;
-
- // Open dataset
- String datasetPath = split.getDatasetPath();
- try {
- this.dataset = Dataset.open(datasetPath, allocator);
- } catch (Exception e) {
- throw new IOException("Cannot open dataset: " + datasetPath, e);
- }
-
- // Initialize converter
- RowType actualRowType = this.rowType;
- if (actualRowType == null) {
- Schema arrowSchema = dataset.getSchema();
- actualRowType = LanceTypeConverter.toFlinkRowType(arrowSchema);
- }
- this.converter = new RowDataConverter(actualRowType);
-
- // Get specified Fragment
- List<Fragment> fragments = dataset.getFragments();
- Fragment targetFragment = null;
- for (Fragment fragment : fragments) {
- if (fragment.getId() == split.getFragmentId()) {
- targetFragment = fragment;
- break;
- }
- }
-
- if (targetFragment == null) {
- throw new IOException("Cannot find Fragment: " + split.getFragmentId());
- }
-
- // Build scan options
- ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder();
- scanOptionsBuilder.batchSize(options.getReadBatchSize());
-
- if (selectedColumns != null && selectedColumns.length > 0) {
- scanOptionsBuilder.columns(Arrays.asList(selectedColumns));
- }
-
- String filter = options.getReadFilter();
- if (filter != null && !filter.isEmpty()) {
- scanOptionsBuilder.filter(filter);
- }
-
- ScanOptions scanOptions = scanOptionsBuilder.build();
-
- // Create Scanner
- try {
- this.currentScanner = targetFragment.newScan(scanOptions);
- this.currentReader = currentScanner.scanBatches();
- } catch (Exception e) {
- throw new IOException("Failed to create Scanner", e);
- }
-
- // Load first batch of data
- loadNextBatch();
+ // Build scan options
+ ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder();
+ scanOptionsBuilder.batchSize(options.getReadBatchSize());
+
+ if (selectedColumns != null && selectedColumns.length > 0) {
+ scanOptionsBuilder.columns(Arrays.asList(selectedColumns));
}
- /**
- * Load next batch of data
- */
- private void loadNextBatch() throws IOException {
- try {
- if (currentReader.loadNextBatch()) {
- VectorSchemaRoot root = currentReader.getVectorSchemaRoot();
- List<RowData> rows = converter.toRowDataList(root);
- this.currentBatchIterator = rows.iterator();
- } else {
- this.reachedEnd = true;
- this.currentBatchIterator = null;
- }
- } catch (Exception e) {
- throw new IOException("Failed to load data batch", e);
- }
+ String filter = options.getReadFilter();
+ if (filter != null && !filter.isEmpty()) {
+ scanOptionsBuilder.filter(filter);
}
- @Override
- public boolean reachedEnd() throws IOException {
- return reachedEnd;
+ ScanOptions scanOptions = scanOptionsBuilder.build();
+
+ // Create Scanner
+ try {
+ this.currentScanner = targetFragment.newScan(scanOptions);
+ this.currentReader = currentScanner.scanBatches();
+ } catch (Exception e) {
+ throw new IOException("Failed to create Scanner", e);
}
- @Override
- public RowData nextRecord(RowData reuse) throws IOException {
- if (reachedEnd) {
- return null;
- }
-
- // Current batch still has data
- if (currentBatchIterator != null && currentBatchIterator.hasNext()) {
- return currentBatchIterator.next();
- }
-
- // Load next batch
- loadNextBatch();
-
- if (reachedEnd) {
- return null;
- }
-
- if (currentBatchIterator != null && currentBatchIterator.hasNext()) {
- return currentBatchIterator.next();
- }
-
- return null;
+ // Load first batch of data
+ loadNextBatch();
+ }
+
+ /** Load next batch of data */
+ private void loadNextBatch() throws IOException {
+ try {
+ if (currentReader.loadNextBatch()) {
+ VectorSchemaRoot root = currentReader.getVectorSchemaRoot();
+ List<RowData> rows = converter.toRowDataList(root);
+ this.currentBatchIterator = rows.iterator();
+ } else {
+ this.reachedEnd = true;
+ this.currentBatchIterator = null;
+ }
+ } catch (Exception e) {
+ throw new IOException("Failed to load data batch", e);
}
+ }
- @Override
- public void close() throws IOException {
- LOG.info("Closing LanceInputFormat");
-
- if (currentReader != null) {
- try {
- currentReader.close();
- } catch (Exception e) {
- LOG.warn("Failed to close Reader", e);
- }
- currentReader = null;
- }
-
- if (currentScanner != null) {
- try {
- currentScanner.close();
- } catch (Exception e) {
- LOG.warn("Failed to close Scanner", e);
- }
- currentScanner = null;
- }
-
- if (dataset != null) {
- try {
- dataset.close();
- } catch (Exception e) {
- LOG.warn("Failed to close dataset", e);
- }
- dataset = null;
- }
-
- if (allocator != null) {
- try {
- allocator.close();
- } catch (Exception e) {
- LOG.warn("Failed to close allocator", e);
- }
- allocator = null;
- }
+ @Override
+ public boolean reachedEnd() throws IOException {
+ return reachedEnd;
+ }
+
+ @Override
+ public RowData nextRecord(RowData reuse) throws IOException {
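+ // Note: the reuse instance is ignored; rows converted from the current Arrow batch are returned directly.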
+ if (reachedEnd) {
+ return null;
}
- /**
- * Get RowType
- */
- public RowType getRowType() {
- return rowType;
+ // Current batch still has data
+ if (currentBatchIterator != null && currentBatchIterator.hasNext()) {
+ return currentBatchIterator.next();
}
- /**
- * Get configuration options
- */
- public LanceOptions getOptions() {
- return options;
+ // Load next batch
+ loadNextBatch();
+
+ if (reachedEnd) {
+ return null;
}
- /**
- * Lance split assigner
- */
- private static class LanceSplitAssigner implements InputSplitAssigner {
- private final List<LanceSplit> remainingSplits;
-
- public LanceSplitAssigner(LanceSplit[] splits) {
- this.remainingSplits = new ArrayList<>();
- for (LanceSplit split : splits) {
- remainingSplits.add(split);
- }
- }
+ if (currentBatchIterator != null && currentBatchIterator.hasNext()) {
+ return currentBatchIterator.next();
+ }
- @Override
- public synchronized LanceSplit getNextInputSplit(String host, int taskId) {
- if (remainingSplits.isEmpty()) {
- return null;
- }
- return remainingSplits.remove(remainingSplits.size() - 1);
- }
+ return null;
+ }
- @Override
- public void returnInputSplit(List<org.apache.flink.core.io.InputSplit> splits, int taskId) {
- for (org.apache.flink.core.io.InputSplit split : splits) {
- if (split instanceof LanceSplit) {
- synchronized (this) {
- remainingSplits.add((LanceSplit) split);
- }
- }
- }
+ @Override
+ public void close() throws IOException {
+ LOG.info("Closing LanceInputFormat");
+
+ if (currentReader != null) {
+ try {
+ currentReader.close();
+ } catch (Exception e) {
+ LOG.warn("Failed to close Reader", e);
+ }
+ currentReader = null;
+ }
+
+ if (currentScanner != null) {
+ try {
+ currentScanner.close();
+ } catch (Exception e) {
+ LOG.warn("Failed to close Scanner", e);
+ }
+ currentScanner = null;
+ }
+
+ if (dataset != null) {
+ try {
+ dataset.close();
+ } catch (Exception e) {
+ LOG.warn("Failed to close dataset", e);
+ }
+ dataset = null;
+ }
+
+ if (allocator != null) {
+ try {
+ allocator.close();
+ } catch (Exception e) {
+ LOG.warn("Failed to close allocator", e);
+ }
+ allocator = null;
+ }
+ }
+
+ /** Get RowType */
+ public RowType getRowType() {
+ return rowType;
+ }
+
+ /** Get configuration options */
+ public LanceOptions getOptions() {
+ return options;
+ }
+
+ /** Lance split assigner */
+ private static class LanceSplitAssigner implements InputSplitAssigner {
+ private final List<LanceSplit> remainingSplits;
+
+ LanceSplitAssigner(LanceSplit[] splits) {
+ this.remainingSplits = new ArrayList<>();
+ for (LanceSplit split : splits) {
+ remainingSplits.add(split);
+ }
+ }
+
+ @Override
+ public synchronized LanceSplit getNextInputSplit(String host, int taskId) {
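+ // Host locality is ignored; splits are handed out from the tail of the remaining list.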
+ if (remainingSplits.isEmpty()) {
+ return null;
+ }
+ return remainingSplits.remove(remainingSplits.size() - 1);
+ }
+
+ @Override
+ public void returnInputSplit(List<org.apache.flink.core.io.InputSplit> splits, int taskId) {
+ for (org.apache.flink.core.io.InputSplit split : splits) {
+ if (split instanceof LanceSplit) {
+ synchronized (this) {
+ remainingSplits.add((LanceSplit) split);
+ }
}
+ }
}
+ }
}
diff --git a/src/main/java/org/apache/flink/connector/lance/LanceSink.java b/src/main/java/org/apache/flink/connector/lance/LanceSink.java
index feeec90..842b990 100644
--- a/src/main/java/org/apache/flink/connector/lance/LanceSink.java
+++ b/src/main/java/org/apache/flink/connector/lance/LanceSink.java
@@ -1,11 +1,7 @@
/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
@@ -15,7 +11,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.flink.connector.lance;
import org.apache.flink.configuration.Configuration;
@@ -52,294 +47,282 @@
/**
* Lance Sink implementation.
- *
+ *
* Writes Flink RowData to Lance dataset, supports batch writing and Checkpoint.
- *
+ *
* <p>Usage example:
+ *
* <pre>{@code
* LanceOptions options = LanceOptions.builder()
* .path("/path/to/lance/dataset")
* .writeBatchSize(1024)
* .writeMode(WriteMode.APPEND)
* .build();
- *
+ *
* LanceSink sink = new LanceSink(options, rowType);
* dataStream.addSink(sink);
* }</pre>
*/
public class LanceSink extends RichSinkFunction<RowData> implements CheckpointedFunction {
- private static final long serialVersionUID = 1L;
- private static final Logger LOG = LoggerFactory.getLogger(LanceSink.class);
-
- private final LanceOptions options;
- private final RowType rowType;
-
- private transient BufferAllocator allocator;
- private transient Dataset dataset;
- private transient RowDataConverter converter;
- private transient Schema arrowSchema;
- private transient List<RowData> buffer;
- private transient long totalWrittenRows;
- private transient boolean datasetExists;
- private transient boolean isFirstWrite;
-
- /**
- * Create LanceSink
- *
- * @param options Lance configuration options
- * @param rowType Flink RowType
- */
- public LanceSink(LanceOptions options, RowType rowType) {
- this.options = options;
- this.rowType = rowType;
+ private static final long serialVersionUID = 1L;
+ private static final Logger LOG = LoggerFactory.getLogger(LanceSink.class);
+
+ private final LanceOptions options;
+ private final RowType rowType;
+
+ private transient BufferAllocator allocator;
+ private transient Dataset dataset;
+ private transient RowDataConverter converter;
+ private transient Schema arrowSchema;
+ private transient List<RowData> buffer;
+ private transient long totalWrittenRows;
+ private transient boolean datasetExists;
+ private transient boolean isFirstWrite;
+
+ /**
+ * Create LanceSink
+ *
+ * @param options Lance configuration options
+ * @param rowType Flink RowType
+ */
+ public LanceSink(LanceOptions options, RowType rowType) {
+ this.options = options;
+ this.rowType = rowType;
+ }
+
+ @Override
+ public void open(Configuration parameters) throws Exception {
+ super.open(parameters);
+
+ LOG.info("Opening Lance Sink: {}", options.getPath());
+
+ this.allocator = new RootAllocator(Long.MAX_VALUE);
+ this.buffer = new ArrayList<>(options.getWriteBatchSize());
+ this.totalWrittenRows = 0;
+ this.isFirstWrite = true;
+
+ // Initialize converter and Schema
+ this.converter = new RowDataConverter(rowType);
+ this.arrowSchema = LanceTypeConverter.toArrowSchema(rowType);
+
+ // Check if dataset exists
+ String datasetPath = options.getPath();
+ if (datasetPath == null || datasetPath.isEmpty()) {
+ throw new IllegalArgumentException("Lance dataset path cannot be empty");
}
- @Override
- public void open(Configuration parameters) throws Exception {
- super.open(parameters);
-
- LOG.info("Opening Lance Sink: {}", options.getPath());
-
- this.allocator = new RootAllocator(Long.MAX_VALUE);
- this.buffer = new ArrayList<>(options.getWriteBatchSize());
- this.totalWrittenRows = 0;
- this.isFirstWrite = true;
-
- // Initialize converter and Schema
- this.converter = new RowDataConverter(rowType);
- this.arrowSchema = LanceTypeConverter.toArrowSchema(rowType);
-
- // Check if dataset exists
- String datasetPath = options.getPath();
- if (datasetPath == null || datasetPath.isEmpty()) {
- throw new IllegalArgumentException("Lance dataset path cannot be empty");
- }
-
- Path path = Paths.get(datasetPath);
- this.datasetExists = Files.exists(path);
-
- // If overwrite mode and dataset exists, delete first
- if (datasetExists && options.getWriteMode() == LanceOptions.WriteMode.OVERWRITE) {
- LOG.info("Overwrite mode, deleting existing dataset: {}", datasetPath);
- deleteDirectory(path);
- this.datasetExists = false;
- }
-
- LOG.info("Lance Sink opened, Schema: {}", rowType);
+ Path path = Paths.get(datasetPath);
+ this.datasetExists = Files.exists(path);
+
+ // If overwrite mode and dataset exists, delete first
+ if (datasetExists && options.getWriteMode() == LanceOptions.WriteMode.OVERWRITE) {
+ LOG.info("Overwrite mode, deleting existing dataset: {}", datasetPath);
+ deleteDirectory(path);
+ this.datasetExists = false;
}
- @Override
- public void invoke(RowData value, Context context) throws Exception {
- buffer.add(value);
-
- // When buffer reaches batch size, execute write
- if (buffer.size() >= options.getWriteBatchSize()) {
- flush();
- }
+ LOG.info("Lance Sink opened, Schema: {}", rowType);
+ }
+
+ @Override
+ public void invoke(RowData value, Context context) throws Exception {
+ buffer.add(value);
+
+ // When buffer reaches batch size, execute write
+ if (buffer.size() >= options.getWriteBatchSize()) {
+ flush();
}
+ }
- /**
- * Flush buffer, write data to Lance dataset
- */
- public void flush() throws IOException {
- if (buffer.isEmpty()) {
- return;
- }
-
- LOG.debug("Flushing buffer, row count: {}", buffer.size());
-
- try (VectorSchemaRoot root = VectorSchemaRoot.create(arrowSchema, allocator)) {
- // Convert RowData to VectorSchemaRoot
- converter.toVectorSchemaRoot(buffer, root);
-
- String datasetPath = options.getPath();
-
- // Build write parameters
- WriteParams writeParams = new WriteParams.Builder()
- .withMaxRowsPerFile(options.getWriteMaxRowsPerFile())
- .build();
-
- // Create Fragment
- List<FragmentMetadata> fragments = Fragment.create(
- datasetPath,
- allocator,
- root,
- writeParams
- );
-
- if (!datasetExists) {
- // Create new dataset (using Overwrite operation)
- FragmentOperation.Overwrite overwrite = new FragmentOperation.Overwrite(fragments, arrowSchema);
- dataset = overwrite.commit(allocator, datasetPath, Optional.empty(), Collections.emptyMap());
- datasetExists = true;
- isFirstWrite = false;
- LOG.info("Created new dataset: {}", datasetPath);
- } else {
- // Append data
- if (isFirstWrite && options.getWriteMode() == LanceOptions.WriteMode.OVERWRITE) {
- // First write and overwrite mode
- FragmentOperation.Overwrite overwrite = new FragmentOperation.Overwrite(fragments, arrowSchema);
- dataset = overwrite.commit(allocator, datasetPath, Optional.empty(), Collections.emptyMap());
- isFirstWrite = false;
- } else {
- // Append mode
- FragmentOperation.Append append = new FragmentOperation.Append(fragments);
- dataset = append.commit(allocator, datasetPath, Optional.empty(), Collections.emptyMap());
- }
- }
-
- totalWrittenRows += buffer.size();
- LOG.debug("Written {} rows, total: {} rows", buffer.size(), totalWrittenRows);
-
- buffer.clear();
- } catch (Exception e) {
- throw new IOException("Failed to write Lance dataset", e);
- }
+ /** Flush buffer, write data to Lance dataset */
+ public void flush() throws IOException {
+ if (buffer.isEmpty()) {
+ return;
}
- @Override
- public void close() throws Exception {
- LOG.info("Closing Lance Sink");
- // Flush remaining data
- try {
- flush();
- } catch (Exception e) {
- LOG.warn("Failed to flush data on close", e);
- }
- if (dataset != null) {
- try {
- dataset.close();
- } catch (Exception e) {
- LOG.warn("Failed to close dataset", e);
- }
- dataset = null;
- }
-
- if (allocator != null) {
- try {
- allocator.close();
- } catch (Exception e) {
- LOG.warn("Failed to close allocator", e);
- }
- allocator = null;
+ LOG.debug("Flushing buffer, row count: {}", buffer.size());
+
+ try (VectorSchemaRoot root = VectorSchemaRoot.create(arrowSchema, allocator)) {
+ // Convert RowData to VectorSchemaRoot
+ converter.toVectorSchemaRoot(buffer, root);
+
+ String datasetPath = options.getPath();
+
+ // Build write parameters
+ WriteParams writeParams =
+ new WriteParams.Builder().withMaxRowsPerFile(options.getWriteMaxRowsPerFile()).build();
+
+ // Create Fragment
+ List<FragmentMetadata> fragments = Fragment.create(datasetPath, allocator, root, writeParams);
+
+ if (!datasetExists) {
+ // Create new dataset (using Overwrite operation)
+ FragmentOperation.Overwrite overwrite =
+ new FragmentOperation.Overwrite(fragments, arrowSchema);
+ dataset =
+ overwrite.commit(allocator, datasetPath, Optional.empty(), Collections.emptyMap());
+ datasetExists = true;
+ isFirstWrite = false;
+ LOG.info("Created new dataset: {}", datasetPath);
+ } else {
+ // Append data
+ if (isFirstWrite && options.getWriteMode() == LanceOptions.WriteMode.OVERWRITE) {
+ // First write and overwrite mode
+ FragmentOperation.Overwrite overwrite =
+ new FragmentOperation.Overwrite(fragments, arrowSchema);
+ dataset =
+ overwrite.commit(allocator, datasetPath, Optional.empty(), Collections.emptyMap());
+ isFirstWrite = false;
+ } else {
+ // Append mode
+ FragmentOperation.Append append = new FragmentOperation.Append(fragments);
+ dataset = append.commit(allocator, datasetPath, Optional.empty(), Collections.emptyMap());
}
-
- LOG.info("Lance Sink closed, total written {} rows", totalWrittenRows);
-
- super.close();
- }
+ }
- @Override
- public void snapshotState(FunctionSnapshotContext context) throws Exception {
- LOG.debug("Snapshot state, checkpointId: {}", context.getCheckpointId());
-
- // Flush all buffered data at Checkpoint
- flush();
- }
+ totalWrittenRows += buffer.size();
+ LOG.debug("Written {} rows, total: {} rows", buffer.size(), totalWrittenRows);
- @Override
- public void initializeState(FunctionInitializationContext context) throws Exception {
- LOG.debug("Initialize state, isRestored: {}", context.isRestored());
- // State initialization (if recovery needed)
+ buffer.clear();
+ } catch (Exception e) {
+ throw new IOException("Failed to write Lance dataset", e);
}
+ }
- /**
- * Get RowType
- */
- public RowType getRowType() {
- return rowType;
+ @Override
+ public void close() throws Exception {
+ LOG.info("Closing Lance Sink");
+ // Flush remaining data
+ try {
+ flush();
+ } catch (Exception e) {
+ LOG.warn("Failed to flush data on close", e);
}
-
- /**
- * Get configuration options
- */
- public LanceOptions getOptions() {
- return options;
+ if (dataset != null) {
+ try {
+ dataset.close();
+ } catch (Exception e) {
+ LOG.warn("Failed to close dataset", e);
+ }
+ dataset = null;
}
- /**
- * Get total written row count
- */
- public long getTotalWrittenRows() {
- return totalWrittenRows;
+ if (allocator != null) {
+ try {
+ allocator.close();
+ } catch (Exception e) {
+ LOG.warn("Failed to close allocator", e);
+ }
+ allocator = null;
}
- /**
- * Recursively delete directory
- */
- private void deleteDirectory(Path path) throws IOException {
- if (Files.isDirectory(path)) {
- Files.list(path).forEach(child -> {
+ LOG.info("Lance Sink closed, total written {} rows", totalWrittenRows);
+
+ super.close();
+ }
+
+ @Override
+ public void snapshotState(FunctionSnapshotContext context) throws Exception {
+ LOG.debug("Snapshot state, checkpointId: {}", context.getCheckpointId());
+
+ // Flush all buffered data at Checkpoint
+ flush();
+ }
+
+ @Override
+ public void initializeState(FunctionInitializationContext context) throws Exception {
+ LOG.debug("Initialize state, isRestored: {}", context.isRestored());
+ // State initialization (if recovery needed)
+ }
+
+ /** Get RowType */
+ public RowType getRowType() {
+ return rowType;
+ }
+
+ /** Get configuration options */
+ public LanceOptions getOptions() {
+ return options;
+ }
+
+ /** Get total written row count */
+ public long getTotalWrittenRows() {
+ return totalWrittenRows;
+ }
+
+ /** Recursively delete directory */
+ private void deleteDirectory(Path path) throws IOException {
+ if (Files.isDirectory(path)) {
+ Files.list(path)
+ .forEach(
+ child -> {
try {
- deleteDirectory(child);
+ deleteDirectory(child);
} catch (IOException e) {
- LOG.warn("Failed to delete file: {}", child, e);
+ LOG.warn("Failed to delete file: {}", child, e);
}
- });
- }
- Files.deleteIfExists(path);
+ });
}
+ Files.deleteIfExists(path);
+ }
+
+ /** Builder pattern constructor */
+ public static Builder builder() {
+ return new Builder();
+ }
- /**
- * Builder pattern constructor
- */
- public static Builder builder() {
- return new Builder();
+ /** LanceSink Builder */
+ public static class Builder {
+ private String path;
+ private int batchSize = 1024;
+ private LanceOptions.WriteMode writeMode = LanceOptions.WriteMode.APPEND;
+ private int maxRowsPerFile = 1000000;
+ private RowType rowType;
+
+ public Builder path(String path) {
+ this.path = path;
+ return this;
}
- /**
- * LanceSink Builder
- */
- public static class Builder {
- private String path;
- private int batchSize = 1024;
- private LanceOptions.WriteMode writeMode = LanceOptions.WriteMode.APPEND;
- private int maxRowsPerFile = 1000000;
- private RowType rowType;
-
- public Builder path(String path) {
- this.path = path;
- return this;
- }
+ public Builder batchSize(int batchSize) {
+ this.batchSize = batchSize;
+ return this;
+ }
- public Builder batchSize(int batchSize) {
- this.batchSize = batchSize;
- return this;
- }
+ public Builder writeMode(LanceOptions.WriteMode writeMode) {
+ this.writeMode = writeMode;
+ return this;
+ }
- public Builder writeMode(LanceOptions.WriteMode writeMode) {
- this.writeMode = writeMode;
- return this;
- }
+ public Builder maxRowsPerFile(int maxRowsPerFile) {
+ this.maxRowsPerFile = maxRowsPerFile;
+ return this;
+ }
- public Builder maxRowsPerFile(int maxRowsPerFile) {
- this.maxRowsPerFile = maxRowsPerFile;
- return this;
- }
+ public Builder rowType(RowType rowType) {
+ this.rowType = rowType;
+ return this;
+ }
- public Builder rowType(RowType rowType) {
- this.rowType = rowType;
- return this;
- }
+ public LanceSink build() {
+ if (path == null || path.isEmpty()) {
+ throw new IllegalArgumentException("Dataset path cannot be empty");
+ }
- public LanceSink build() {
- if (path == null || path.isEmpty()) {
- throw new IllegalArgumentException("Dataset path cannot be empty");
- }
-
- if (rowType == null) {
- throw new IllegalArgumentException("RowType cannot be null");
- }
-
- LanceOptions options = LanceOptions.builder()
- .path(path)
- .writeBatchSize(batchSize)
- .writeMode(writeMode)
- .writeMaxRowsPerFile(maxRowsPerFile)
- .build();
-
- return new LanceSink(options, rowType);
- }
+ if (rowType == null) {
+ throw new IllegalArgumentException("RowType cannot be null");
+ }
+
+ LanceOptions options =
+ LanceOptions.builder()
+ .path(path)
+ .writeBatchSize(batchSize)
+ .writeMode(writeMode)
+ .writeMaxRowsPerFile(maxRowsPerFile)
+ .build();
+
+ return new LanceSink(options, rowType);
}
+ }
}
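
Reviewer note: a minimal usage sketch for the sink as reformatted above. The schema, dataset path, and sample row are illustrative only; batchSize(1024) and WriteMode.APPEND mirror the Builder defaults in this patch.

    import org.apache.flink.streaming.api.datastream.DataStream;
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
    import org.apache.flink.table.data.GenericRowData;
    import org.apache.flink.table.data.RowData;
    import org.apache.flink.table.data.StringData;
    import org.apache.flink.table.types.logical.IntType;
    import org.apache.flink.table.types.logical.RowType;
    import org.apache.flink.table.types.logical.VarCharType;

    public class LanceSinkExample {
        public static void main(String[] args) throws Exception {
            StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

            // Hypothetical (id INT, name VARCHAR) schema for this sketch
            RowType rowType = RowType.of(new IntType(), new VarCharType(255));

            GenericRowData row = new GenericRowData(2);
            row.setField(0, 1);
            row.setField(1, StringData.fromString("alice"));

            DataStream<RowData> stream = env.fromElements((RowData) row);

            stream.addSink(
                    LanceSink.builder()
                            .path("/tmp/lance/users.lance") // hypothetical dataset path
                            .rowType(rowType)
                            .batchSize(1024) // flush() fires when the buffer reaches this size
                            .writeMode(LanceOptions.WriteMode.APPEND)
                            .build());

            env.execute("lance-sink-example");
        }
    }

Since snapshotState also calls flush(), enabling checkpointing guarantees the buffer is drained at every checkpoint rather than only when it fills.
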
diff --git a/src/main/java/org/apache/flink/connector/lance/LanceSource.java b/src/main/java/org/apache/flink/connector/lance/LanceSource.java
index ade00a0..6b15699 100644
--- a/src/main/java/org/apache/flink/connector/lance/LanceSource.java
+++ b/src/main/java/org/apache/flink/connector/lance/LanceSource.java
@@ -1,11 +1,7 @@
/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
@@ -15,10 +11,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.flink.connector.lance;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.connector.lance.config.LanceOptions;
import org.apache.flink.connector.lance.converter.LanceTypeConverter;
@@ -47,364 +41,354 @@
/**
 * Lance data source implementation.
- *
+ *
 * <p>Reads data from Lance dataset and converts to Flink RowData.
+ *
 * <p>Supports column pruning, predicate push-down and Limit push-down optimization.
- *
+ *
 * <p>Usage example:
+ *
 * <pre>{@code
* LanceOptions options = LanceOptions.builder()
* .path("/path/to/lance/dataset")
* .readBatchSize(1024)
* .readLimit(100L) // Limit push-down
* .build();
- *
+ *
* LanceSource source = new LanceSource(options, rowType);
* DataStream<RowData> stream = env.addSource(source);
* }</pre>
*/
public class LanceSource extends RichParallelSourceFunction<RowData> {
- private static final long serialVersionUID = 1L;
- private static final Logger LOG = LoggerFactory.getLogger(LanceSource.class);
-
- private final LanceOptions options;
- private final RowType rowType;
- private final String[] selectedColumns;
- private final Long readLimit; // Added: Limit push-down
-
- private transient volatile boolean running;
- private transient BufferAllocator allocator;
- private transient Dataset dataset;
- private transient RowDataConverter converter;
- private transient long emittedCount; // Added: emitted row count
-
- /**
- * Create LanceSource
- *
- * @param options Lance configuration options
- * @param rowType Flink RowType
- */
- public LanceSource(LanceOptions options, RowType rowType) {
- this.options = options;
- this.rowType = rowType;
-
- List<String> columns = options.getReadColumns();
- this.selectedColumns = columns != null && !columns.isEmpty()
- ? columns.toArray(new String[0])
- : null;
- this.readLimit = options.getReadLimit();
+ private static final long serialVersionUID = 1L;
+ private static final Logger LOG = LoggerFactory.getLogger(LanceSource.class);
+
+ private final LanceOptions options;
+ private final RowType rowType;
+ private final String[] selectedColumns;
+ private final Long readLimit; // Added: Limit push-down
+
+ private transient volatile boolean running;
+ private transient BufferAllocator allocator;
+ private transient Dataset dataset;
+ private transient RowDataConverter converter;
+ private transient long emittedCount; // Added: emitted row count
+
+ /**
+ * Create LanceSource
+ *
+ * @param options Lance configuration options
+ * @param rowType Flink RowType
+ */
+ public LanceSource(LanceOptions options, RowType rowType) {
+ this.options = options;
+ this.rowType = rowType;
+
+ List<String> columns = options.getReadColumns();
+ this.selectedColumns =
+ columns != null && !columns.isEmpty() ? columns.toArray(new String[0]) : null;
+ this.readLimit = options.getReadLimit();
+ }
+
+ /**
+ * Create LanceSource (auto-infer Schema)
+ *
+ * @param options Lance configuration options
+ */
+ public LanceSource(LanceOptions options) {
+ this(options, null);
+ }
+
+ @Override
+ public void open(Configuration parameters) throws Exception {
+ super.open(parameters);
+
+ LOG.info("Opening Lance data source: {}", options.getPath());
+ if (readLimit != null) {
+ LOG.info("Limit push-down enabled, max read rows: {}", readLimit);
}
- /**
- * Create LanceSource (auto-infer Schema)
- *
- * @param options Lance configuration options
- */
- public LanceSource(LanceOptions options) {
- this(options, null);
+ this.running = true;
+ this.emittedCount = 0;
+ this.allocator = new RootAllocator(Long.MAX_VALUE);
+
+ // Open Lance dataset
+ String datasetPath = options.getPath();
+ if (datasetPath == null || datasetPath.isEmpty()) {
+ throw new IllegalArgumentException("Lance dataset path cannot be empty");
}
- @Override
- public void open(Configuration parameters) throws Exception {
- super.open(parameters);
-
- LOG.info("Opening Lance data source: {}", options.getPath());
- if (readLimit != null) {
- LOG.info("Limit push-down enabled, max read rows: {}", readLimit);
- }
-
- this.running = true;
- this.emittedCount = 0;
- this.allocator = new RootAllocator(Long.MAX_VALUE);
-
- // Open Lance dataset
- String datasetPath = options.getPath();
- if (datasetPath == null || datasetPath.isEmpty()) {
- throw new IllegalArgumentException("Lance dataset path cannot be empty");
- }
-
- Path path = Paths.get(datasetPath);
- try {
- this.dataset = Dataset.open(path.toString(), allocator);
- } catch (Exception e) {
- throw new IOException("Cannot open Lance dataset: " + datasetPath, e);
- }
-
- // Initialize RowDataConverter
- RowType actualRowType = this.rowType;
- if (actualRowType == null) {
- // Infer RowType from dataset Schema
- Schema arrowSchema = dataset.getSchema();
- actualRowType = LanceTypeConverter.toFlinkRowType(arrowSchema);
- }
- this.converter = new RowDataConverter(actualRowType);
-
- LOG.info("Lance data source opened, Schema: {}", actualRowType);
+ Path path = Paths.get(datasetPath);
+ try {
+ this.dataset = Dataset.open(path.toString(), allocator);
+ } catch (Exception e) {
+ throw new IOException("Cannot open Lance dataset: " + datasetPath, e);
}
- @Override
- public void run(SourceContext<RowData> ctx) throws Exception {
- LOG.info("Start reading Lance dataset: {}", options.getPath());
-
- int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
- int numSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
-
- String filter = options.getReadFilter();
-
- // If filter condition exists, use Dataset level scan (only execute on first subtask to avoid duplicate data)
- if (filter != null && !filter.isEmpty()) {
- if (subtaskIndex == 0) {
- LOG.info("Using Dataset level scan (with filter condition)");
- readDatasetWithFilter(ctx);
- } else {
- LOG.info("Subtask {} skipped (only subtask 0 executes in filter mode)", subtaskIndex);
- }
- } else if (readLimit != null) {
- // With Limit, only execute on first subtask to avoid duplicate data
- if (subtaskIndex == 0) {
- LOG.info("Using Dataset level scan (with Limit)");
- readDatasetWithFilter(ctx);
- } else {
- LOG.info("Subtask {} skipped (only subtask 0 executes in Limit mode)", subtaskIndex);
- }
- } else {
- // Without filter condition and Limit, use Fragment level parallel scan
- List<Fragment> fragments = dataset.getFragments();
- LOG.info("Dataset has {} Fragments, current subtask {}/{}",
- fragments.size(), subtaskIndex, numSubtasks);
-
- // Assign Fragments by subtask
- for (int i = 0; i < fragments.size() && running && !isLimitReached(); i++) {
- // Simple round-robin assignment strategy
- if (i % numSubtasks != subtaskIndex) {
- continue;
- }
-
- Fragment fragment = fragments.get(i);
- readFragment(ctx, fragment);
- }
- }
-
- LOG.info("Lance data source read completed, total emitted {} rows", emittedCount);
+ // Initialize RowDataConverter
+ RowType actualRowType = this.rowType;
+ if (actualRowType == null) {
+ // Infer RowType from dataset Schema
+ Schema arrowSchema = dataset.getSchema();
+ actualRowType = LanceTypeConverter.toFlinkRowType(arrowSchema);
}
+ this.converter = new RowDataConverter(actualRowType);
- /**
- * Use Dataset level scan (supports filter conditions and Limit)
- */
- private void readDatasetWithFilter(SourceContext<RowData> ctx) throws Exception {
- // Build scan options
- ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder();
-
- // Set batch size
- scanOptionsBuilder.batchSize(options.getReadBatchSize());
-
- // Set column filter
- if (selectedColumns != null && selectedColumns.length > 0) {
- scanOptionsBuilder.columns(Arrays.asList(selectedColumns));
- }
-
- // Set data filter condition
- String filter = options.getReadFilter();
- if (filter != null && !filter.isEmpty()) {
- LOG.info("Applying filter condition: {}", filter);
- scanOptionsBuilder.filter(filter);
- }
-
- ScanOptions scanOptions = scanOptionsBuilder.build();
-
- // Use Dataset level scan
- try (LanceScanner scanner = dataset.newScan(scanOptions)) {
- try (ArrowReader reader = scanner.scanBatches()) {
- while (reader.loadNextBatch() && running && !isLimitReached()) {
- VectorSchemaRoot root = reader.getVectorSchemaRoot();
-
- // Convert to RowData and output
- List<RowData> rows = converter.toRowDataList(root);
- synchronized (ctx.getCheckpointLock()) {
- for (RowData row : rows) {
- if (isLimitReached()) {
- break;
- }
- ctx.collect(row);
- emittedCount++;
- }
- }
- }
- }
- }
-
- if (isLimitReached()) {
- LOG.info("Reached Limit ({}), stop reading", readLimit);
+ LOG.info("Lance data source opened, Schema: {}", actualRowType);
+ }
+
+ @Override
+ public void run(SourceContext<RowData> ctx) throws Exception {
+ LOG.info("Start reading Lance dataset: {}", options.getPath());
+
+ int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
+ int numSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
+
+ String filter = options.getReadFilter();
+
+ // If filter condition exists, use Dataset level scan (only execute on first subtask to avoid
+ // duplicate data)
+ if (filter != null && !filter.isEmpty()) {
+ if (subtaskIndex == 0) {
+ LOG.info("Using Dataset level scan (with filter condition)");
+ readDatasetWithFilter(ctx);
+ } else {
+ LOG.info("Subtask {} skipped (only subtask 0 executes in filter mode)", subtaskIndex);
+ }
+ } else if (readLimit != null) {
+ // With Limit, only execute on first subtask to avoid duplicate data
+ if (subtaskIndex == 0) {
+ LOG.info("Using Dataset level scan (with Limit)");
+ readDatasetWithFilter(ctx);
+ } else {
+ LOG.info("Subtask {} skipped (only subtask 0 executes in Limit mode)", subtaskIndex);
+ }
+ } else {
+ // Without filter condition and Limit, use Fragment level parallel scan
+ List<Fragment> fragments = dataset.getFragments();
+ LOG.info(
+ "Dataset has {} Fragments, current subtask {}/{}",
+ fragments.size(),
+ subtaskIndex,
+ numSubtasks);
+
+ // Assign Fragments by subtask
+ for (int i = 0; i < fragments.size() && running && !isLimitReached(); i++) {
+ // Simple round-robin assignment strategy
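+ // e.g. 10 fragments at parallelism 4: subtask 0 reads fragments 0, 4, 8;
+ // subtask 1 reads 1, 5, 9; subtasks 2 and 3 read the remainder.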
+ if (i % numSubtasks != subtaskIndex) {
+ continue;
}
+
+ Fragment fragment = fragments.get(i);
+ readFragment(ctx, fragment);
+ }
}
- /**
- * Read single Fragment (without filter condition, but supports Limit)
- */
- private void readFragment(SourceContext<RowData> ctx, Fragment fragment) throws Exception {
- LOG.debug("Reading Fragment: {}", fragment.getId());
-
- // Build scan options
- ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder();
-
- // Set batch size
- scanOptionsBuilder.batchSize(options.getReadBatchSize());
-
- // Set column filter
- if (selectedColumns != null && selectedColumns.length > 0) {
- scanOptionsBuilder.columns(Arrays.asList(selectedColumns));
- }
-
- // Note: Fragment level scan does not use filter, filter is only supported at Dataset level
-
- ScanOptions scanOptions = scanOptionsBuilder.build();
-
- // Create Scanner and read data
- try (LanceScanner scanner = fragment.newScan(scanOptions)) {
- try (ArrowReader reader = scanner.scanBatches()) {
- while (reader.loadNextBatch() && running && !isLimitReached()) {
- VectorSchemaRoot root = reader.getVectorSchemaRoot();
-
- // Convert to RowData and output
- List<RowData> rows = converter.toRowDataList(root);
- synchronized (ctx.getCheckpointLock()) {
- for (RowData row : rows) {
- if (isLimitReached()) {
- break;
- }
- ctx.collect(row);
- emittedCount++;
- }
- }
- }
+ LOG.info("Lance data source read completed, total emitted {} rows", emittedCount);
+ }
+
+ /** Use Dataset level scan (supports filter conditions and Limit) */
+ private void readDatasetWithFilter(SourceContext<RowData> ctx) throws Exception {
+ // Build scan options
+ ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder();
+
+ // Set batch size
+ scanOptionsBuilder.batchSize(options.getReadBatchSize());
+
+ // Set column filter
+ if (selectedColumns != null && selectedColumns.length > 0) {
+ scanOptionsBuilder.columns(Arrays.asList(selectedColumns));
+ }
+
+ // Set data filter condition
+ String filter = options.getReadFilter();
+ if (filter != null && !filter.isEmpty()) {
+ LOG.info("Applying filter condition: {}", filter);
+ scanOptionsBuilder.filter(filter);
+ }
+
+ ScanOptions scanOptions = scanOptionsBuilder.build();
+
+ // Use Dataset level scan
+ try (LanceScanner scanner = dataset.newScan(scanOptions)) {
+ try (ArrowReader reader = scanner.scanBatches()) {
+ while (reader.loadNextBatch() && running && !isLimitReached()) {
+ VectorSchemaRoot root = reader.getVectorSchemaRoot();
+
+ // Convert to RowData and output
+ List<RowData> rows = converter.toRowDataList(root);
+ synchronized (ctx.getCheckpointLock()) {
+ for (RowData row : rows) {
+ if (isLimitReached()) {
+ break;
+ }
+ ctx.collect(row);
+ emittedCount++;
}
+ }
}
+ }
}
- /**
- * Check if Limit has been reached
- */
- private boolean isLimitReached() {
- return readLimit != null && emittedCount >= readLimit;
+ if (isLimitReached()) {
+ LOG.info("Reached Limit ({}), stop reading", readLimit);
}
+ }
- @Override
- public void cancel() {
- LOG.info("Cancel Lance data source");
- this.running = false;
+ /** Read single Fragment (without filter condition, but supports Limit) */
+ private void readFragment(SourceContext<RowData> ctx, Fragment fragment) throws Exception {
+ LOG.debug("Reading Fragment: {}", fragment.getId());
+
+ // Build scan options
+ ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder();
+
+ // Set batch size
+ scanOptionsBuilder.batchSize(options.getReadBatchSize());
+
+ // Set column filter
+ if (selectedColumns != null && selectedColumns.length > 0) {
+ scanOptionsBuilder.columns(Arrays.asList(selectedColumns));
}
- @Override
- public void close() throws Exception {
- LOG.info("Closing Lance data source");
-
- this.running = false;
-
- if (dataset != null) {
- try {
- dataset.close();
- } catch (Exception e) {
- LOG.warn("Error closing Lance dataset", e);
- }
- dataset = null;
- }
-
- if (allocator != null) {
- try {
- allocator.close();
- } catch (Exception e) {
- LOG.warn("Error closing memory allocator", e);
+ // Note: Fragment level scan does not use filter, filter is only supported at Dataset level
+
+ ScanOptions scanOptions = scanOptionsBuilder.build();
+
+ // Create Scanner and read data
+ try (LanceScanner scanner = fragment.newScan(scanOptions)) {
+ try (ArrowReader reader = scanner.scanBatches()) {
+ while (reader.loadNextBatch() && running && !isLimitReached()) {
+ VectorSchemaRoot root = reader.getVectorSchemaRoot();
+
+ // Convert to RowData and output
+ List<RowData> rows = converter.toRowDataList(root);
+ synchronized (ctx.getCheckpointLock()) {
+ for (RowData row : rows) {
+ if (isLimitReached()) {
+ break;
+ }
+ ctx.collect(row);
+ emittedCount++;
}
- allocator = null;
+ }
}
-
- super.close();
+ }
}
+ }
- /**
- * Get RowType
- */
- public RowType getRowType() {
- return rowType;
- }
+ /** Check if Limit has been reached */
+ private boolean isLimitReached() {
+ return readLimit != null && emittedCount >= readLimit;
+ }
- /**
- * Get configuration options
- */
- public LanceOptions getOptions() {
- return options;
- }
+ @Override
+ public void cancel() {
+ LOG.info("Cancel Lance data source");
+ this.running = false;
+ }
- /**
- * Get selected columns
- */
- public String[] getSelectedColumns() {
- return selectedColumns;
+ @Override
+ public void close() throws Exception {
+ LOG.info("Closing Lance data source");
+
+ this.running = false;
+
+ if (dataset != null) {
+ try {
+ dataset.close();
+ } catch (Exception e) {
+ LOG.warn("Error closing Lance dataset", e);
+ }
+ dataset = null;
}
- /**
- * Builder pattern constructor
- */
- public static Builder builder() {
- return new Builder();
+ if (allocator != null) {
+ try {
+ allocator.close();
+ } catch (Exception e) {
+ LOG.warn("Error closing memory allocator", e);
+ }
+ allocator = null;
}
- /**
- * LanceSource Builder
- */
- public static class Builder {
- private String path;
- private int batchSize = 1024;
- private List<String> columns;
- private String filter;
- private Long limit; // Added
- private RowType rowType;
-
- public Builder path(String path) {
- this.path = path;
- return this;
- }
+ super.close();
+ }
- public Builder batchSize(int batchSize) {
- this.batchSize = batchSize;
- return this;
- }
+ /** Get RowType */
+ public RowType getRowType() {
+ return rowType;
+ }
- public Builder columns(List<String> columns) {
- this.columns = columns;
- return this;
- }
+ /** Get configuration options */
+ public LanceOptions getOptions() {
+ return options;
+ }
- public Builder filter(String filter) {
- this.filter = filter;
- return this;
- }
+ /** Get selected columns */
+ public String[] getSelectedColumns() {
+ return selectedColumns;
+ }
- public Builder limit(Long limit) {
- this.limit = limit;
- return this;
- }
+ /** Builder pattern constructor */
+ public static Builder builder() {
+ return new Builder();
+ }
- public Builder rowType(RowType rowType) {
- this.rowType = rowType;
- return this;
- }
+ /** LanceSource Builder */
+ public static class Builder {
+ private String path;
+ private int batchSize = 1024;
+ private List<String> columns;
+ private String filter;
+ private Long limit; // Added
+ private RowType rowType;
- public LanceSource build() {
- if (path == null || path.isEmpty()) {
- throw new IllegalArgumentException("Dataset path cannot be empty");
- }
+ public Builder path(String path) {
+ this.path = path;
+ return this;
+ }
- LanceOptions options = LanceOptions.builder()
- .path(path)
- .readBatchSize(batchSize)
- .readColumns(columns)
- .readFilter(filter)
- .readLimit(limit)
- .build();
+ public Builder batchSize(int batchSize) {
+ this.batchSize = batchSize;
+ return this;
+ }
- return new LanceSource(options, rowType);
- }
+ public Builder columns(List<String> columns) {
+ this.columns = columns;
+ return this;
+ }
+
+ public Builder filter(String filter) {
+ this.filter = filter;
+ return this;
+ }
+
+ public Builder limit(Long limit) {
+ this.limit = limit;
+ return this;
+ }
+
+ public Builder rowType(RowType rowType) {
+ this.rowType = rowType;
+ return this;
+ }
+
+ public LanceSource build() {
+ if (path == null || path.isEmpty()) {
+ throw new IllegalArgumentException("Dataset path cannot be empty");
+ }
+
+ LanceOptions options =
+ LanceOptions.builder()
+ .path(path)
+ .readBatchSize(batchSize)
+ .readColumns(columns)
+ .readFilter(filter)
+ .readLimit(limit)
+ .build();
+
+ return new LanceSource(options, rowType);
}
+ }
}
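
Reviewer note: a usage sketch exercising column pruning, predicate push-down, and the newly added Limit push-down on the source above. The path and column names are illustrative; leaving rowType unset takes the auto-infer constructor path.

    import java.util.Arrays;
    import org.apache.flink.streaming.api.datastream.DataStream;
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
    import org.apache.flink.table.data.RowData;

    StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

    LanceSource source =
            LanceSource.builder()
                    .path("/tmp/lance/users.lance") // hypothetical dataset path
                    .columns(Arrays.asList("id", "name")) // column pruning
                    .filter("id > 100") // predicate push-down (Dataset-level scan)
                    .limit(1000L) // Limit push-down: stop after 1000 rows
                    .build(); // no rowType: schema is inferred from the dataset in open()

    DataStream<RowData> stream = env.addSource(source);

As run() above shows, setting a filter or a limit routes the entire scan through subtask 0, so those modes effectively run at parallelism 1; only the plain scan distributes fragments round-robin across subtasks.
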
diff --git a/src/main/java/org/apache/flink/connector/lance/LanceSplit.java b/src/main/java/org/apache/flink/connector/lance/LanceSplit.java
index 1aec727..2aa1d2e 100644
--- a/src/main/java/org/apache/flink/connector/lance/LanceSplit.java
+++ b/src/main/java/org/apache/flink/connector/lance/LanceSplit.java
@@ -1,11 +1,7 @@
/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
@@ -15,7 +11,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.flink.connector.lance;
import org.apache.flink.core.io.InputSplit;
@@ -25,97 +20,88 @@
/**
* Lance data split.
- *
+ *
* <p>Represents a Fragment in a Lance dataset, used for parallel data reading.
*/
public class LanceSplit implements InputSplit, Serializable {
- private static final long serialVersionUID = 1L;
-
- /**
- * Split number
- */
- private final int splitNumber;
-
- /**
- * Fragment ID
- */
- private final int fragmentId;
-
- /**
- * Dataset path
- */
- private final String datasetPath;
-
- /**
- * Row count in Fragment (estimated)
- */
- private final long rowCount;
-
- /**
- * Create LanceSplit
- *
- * @param splitNumber Split number
- * @param fragmentId Fragment ID
- * @param datasetPath Dataset path
- * @param rowCount Row count
- */
- public LanceSplit(int splitNumber, int fragmentId, String datasetPath, long rowCount) {
- this.splitNumber = splitNumber;
- this.fragmentId = fragmentId;
- this.datasetPath = datasetPath;
- this.rowCount = rowCount;
- }
-
- @Override
- public int getSplitNumber() {
- return splitNumber;
- }
-
- /**
- * Get Fragment ID
- */
- public int getFragmentId() {
- return fragmentId;
- }
-
- /**
- * Get dataset path
- */
- public String getDatasetPath() {
- return datasetPath;
- }
-
- /**
- * Get row count
- */
- public long getRowCount() {
- return rowCount;
- }
-
- @Override
- public boolean equals(Object o) {
- if (this == o) return true;
- if (o == null || getClass() != o.getClass()) return false;
- LanceSplit that = (LanceSplit) o;
- return splitNumber == that.splitNumber &&
- fragmentId == that.fragmentId &&
- rowCount == that.rowCount &&
- Objects.equals(datasetPath, that.datasetPath);
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(splitNumber, fragmentId, datasetPath, rowCount);
- }
-
- @Override
- public String toString() {
- return "LanceSplit{" +
- "splitNumber=" + splitNumber +
- ", fragmentId=" + fragmentId +
- ", datasetPath='" + datasetPath + '\'' +
- ", rowCount=" + rowCount +
- '}';
- }
+ private static final long serialVersionUID = 1L;
+
+ /** Split number */
+ private final int splitNumber;
+
+ /** Fragment ID */
+ private final int fragmentId;
+
+ /** Dataset path */
+ private final String datasetPath;
+
+ /** Row count in Fragment (estimated) */
+ private final long rowCount;
+
+ /**
+ * Create LanceSplit
+ *
+ * @param splitNumber Split number
+ * @param fragmentId Fragment ID
+ * @param datasetPath Dataset path
+ * @param rowCount Row count
+ */
+ public LanceSplit(int splitNumber, int fragmentId, String datasetPath, long rowCount) {
+ this.splitNumber = splitNumber;
+ this.fragmentId = fragmentId;
+ this.datasetPath = datasetPath;
+ this.rowCount = rowCount;
+ }
+
+ @Override
+ public int getSplitNumber() {
+ return splitNumber;
+ }
+
+ /** Get Fragment ID */
+ public int getFragmentId() {
+ return fragmentId;
+ }
+
+ /** Get dataset path */
+ public String getDatasetPath() {
+ return datasetPath;
+ }
+
+ /** Get row count */
+ public long getRowCount() {
+ return rowCount;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ LanceSplit that = (LanceSplit) o;
+ return splitNumber == that.splitNumber
+ && fragmentId == that.fragmentId
+ && rowCount == that.rowCount
+ && Objects.equals(datasetPath, that.datasetPath);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(splitNumber, fragmentId, datasetPath, rowCount);
+ }
+
+ @Override
+ public String toString() {
+ return "LanceSplit{"
+ + "splitNumber="
+ + splitNumber
+ + ", fragmentId="
+ + fragmentId
+ + ", datasetPath='"
+ + datasetPath
+ + '\''
+ + ", rowCount="
+ + rowCount
+ + '}';
+ }
}
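
Reviewer note: LanceSplit is a plain value object; below is a sketch of how an input format might derive one split per fragment. Fragment#getId appears elsewhere in this patch, while Fragment#countRows is an assumed accessor for the estimated row count.

    List<Fragment> fragments = dataset.getFragments();
    LanceSplit[] splits = new LanceSplit[fragments.size()];
    for (int i = 0; i < fragments.size(); i++) {
        Fragment fragment = fragments.get(i);
        splits[i] = new LanceSplit(i, fragment.getId(), options.getPath(), fragment.countRows());
    }
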
diff --git a/src/main/java/org/apache/flink/connector/lance/LanceVectorSearch.java b/src/main/java/org/apache/flink/connector/lance/LanceVectorSearch.java
index faf0c5a..25018ae 100644
--- a/src/main/java/org/apache/flink/connector/lance/LanceVectorSearch.java
+++ b/src/main/java/org/apache/flink/connector/lance/LanceVectorSearch.java
@@ -1,11 +1,7 @@
/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
@@ -15,7 +11,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.flink.connector.lance;
import org.apache.flink.connector.lance.config.LanceOptions;
@@ -49,10 +44,11 @@
/**
 * Lance vector search implementation.
- *
+ *
 * <p>Supports KNN search with L2, Cosine, and Dot distance metrics.
- *
+ *
 * <p>Usage example:
+ *
 * <pre>{@code
 * LanceVectorSearch search = LanceVectorSearch.builder()
 * .datasetPath("/path/to/dataset")
@@ -60,391 +56,368 @@
* .metricType(MetricType.L2)
* .nprobes(20)
* .build();
- *
+ *
* List<SearchResult> results = search.search(queryVector, 10);
* }</pre>
*/
public class LanceVectorSearch implements Closeable, Serializable {
- private static final long serialVersionUID = 1L;
- private static final Logger LOG = LoggerFactory.getLogger(LanceVectorSearch.class);
-
- private final String datasetPath;
- private final String columnName;
- private final MetricType metricType;
- private final int nprobes;
- private final int ef;
- private final Integer refineFactor;
-
- private transient BufferAllocator allocator;
- private transient Dataset dataset;
- private transient RowType rowType;
- private transient RowDataConverter converter;
-
- private LanceVectorSearch(Builder builder) {
- this.datasetPath = builder.datasetPath;
- this.columnName = builder.columnName;
- this.metricType = builder.metricType;
- this.nprobes = builder.nprobes;
- this.ef = builder.ef;
- this.refineFactor = builder.refineFactor;
+ private static final long serialVersionUID = 1L;
+ private static final Logger LOG = LoggerFactory.getLogger(LanceVectorSearch.class);
+
+ private final String datasetPath;
+ private final String columnName;
+ private final MetricType metricType;
+ private final int nprobes;
+ private final int ef;
+ private final Integer refineFactor;
+
+ private transient BufferAllocator allocator;
+ private transient Dataset dataset;
+ private transient RowType rowType;
+ private transient RowDataConverter converter;
+
+ private LanceVectorSearch(Builder builder) {
+ this.datasetPath = builder.datasetPath;
+ this.columnName = builder.columnName;
+ this.metricType = builder.metricType;
+ this.nprobes = builder.nprobes;
+ this.ef = builder.ef;
+ this.refineFactor = builder.refineFactor;
+ }
+
+ /** Open dataset connection */
+ public void open() throws IOException {
+ LOG.info("Opening vector search, dataset: {}", datasetPath);
+
+ this.allocator = new RootAllocator(Long.MAX_VALUE);
+
+ try {
+ this.dataset = Dataset.open(datasetPath, allocator);
+
+ // Get Schema and create converter
+ Schema arrowSchema = dataset.getSchema();
+ this.rowType = LanceTypeConverter.toFlinkRowType(arrowSchema);
+ this.converter = new RowDataConverter(rowType);
+
+ } catch (Exception e) {
+ throw new IOException("Cannot open dataset: " + datasetPath, e);
}
-
- /**
- * Open dataset connection
- */
- public void open() throws IOException {
- LOG.info("Opening vector search, dataset: {}", datasetPath);
-
- this.allocator = new RootAllocator(Long.MAX_VALUE);
-
- try {
- this.dataset = Dataset.open(datasetPath, allocator);
-
- // Get Schema and create converter
- Schema arrowSchema = dataset.getSchema();
- this.rowType = LanceTypeConverter.toFlinkRowType(arrowSchema);
- this.converter = new RowDataConverter(rowType);
-
- } catch (Exception e) {
- throw new IOException("Cannot open dataset: " + datasetPath, e);
- }
+ }
+
+ /**
+ * Execute vector search
+ *
+ * @param queryVector Query vector
+ * @param k Number of nearest neighbors to return
+ * @return List of search results
+ */
+ public List<SearchResult> search(float[] queryVector, int k) throws IOException {
+ return search(queryVector, k, null);
+ }
+
+ /**
+ * Execute vector search (with filter condition)
+ *
+ * @param queryVector Query vector
+ * @param k Number of nearest neighbors to return
+ * @param filter Filter condition (SQL WHERE syntax)
+ * @return List of search results
+ */
+ public List<SearchResult> search(float[] queryVector, int k, String filter) throws IOException {
+ if (dataset == null) {
+ open();
}
- /**
- * Execute vector search
- *
- * @param queryVector Query vector
- * @param k Number of nearest neighbors to return
- * @return List of search results
- */
- public List<SearchResult> search(float[] queryVector, int k) throws IOException {
- return search(queryVector, k, null);
- }
+ LOG.debug("Executing vector search, k={}, vector dimension={}", k, queryVector.length);
- /**
- * Execute vector search (with filter condition)
- *
- * @param queryVector Query vector
- * @param k Number of nearest neighbors to return
- * @param filter Filter condition (SQL WHERE syntax)
- * @return List of search results
- */
- public List<SearchResult> search(float[] queryVector, int k, String filter) throws IOException {
- if (dataset == null) {
- open();
- }
-
- LOG.debug("Executing vector search, k={}, vector dimension={}", k, queryVector.length);
-
- // Validate query vector
- validateQueryVector(queryVector);
-
- List<SearchResult> results = new ArrayList<>();
-
- try {
- // Build vector query
- Query.Builder queryBuilder = new Query.Builder()
- .setColumn(columnName)
- .setKey(queryVector)
- .setK(k)
- .setNprobes(nprobes)
- .setDistanceType(toDistanceType(metricType))
- .setUseIndex(true);
-
- if (ef > 0) {
- queryBuilder.setEf(ef);
- }
-
- if (refineFactor != null && refineFactor > 0) {
- queryBuilder.setRefineFactor(refineFactor);
- }
-
- Query query = queryBuilder.build();
-
- // Build scan options
- ScanOptions.Builder scanOptionsBuilder = new ScanOptions.Builder()
- .nearest(query)
- .withRowId(true);
-
- if (filter != null && !filter.isEmpty()) {
- scanOptionsBuilder.filter(filter);
- }
-
- ScanOptions scanOptions = scanOptionsBuilder.build();
-
- // Execute search
- try (LanceScanner scanner = dataset.newScan(scanOptions)) {
- try (ArrowReader reader = scanner.scanBatches()) {
- while (reader.loadNextBatch()) {
- VectorSchemaRoot root = reader.getVectorSchemaRoot();
-
- // Convert to RowData
- List<RowData> rows = converter.toRowDataList(root);
-
- // Try to get distance score (if _distance column exists)
- Float8Vector distanceVector = null;
- try {
- distanceVector = (Float8Vector) root.getVector("_distance");
- } catch (Exception e) {
- // _distance column may not exist
- }
-
- for (int i = 0; i < rows.size(); i++) {
- double distance = 0.0;
- if (distanceVector != null && !distanceVector.isNull(i)) {
- distance = distanceVector.get(i);
- }
- results.add(new SearchResult(rows.get(i), distance));
- }
- }
- }
- }
-
- LOG.debug("Search completed, returned {} results", results.size());
- return results;
-
- } catch (Exception e) {
- throw new IOException("Vector search failed", e);
- }
- }
+ // Validate query vector
+ validateQueryVector(queryVector);
- /**
- * Execute vector search (return RowData list)
- *
- * @param queryVector Query vector
- * @param k Number of nearest neighbors to return
- * @return RowData list
- */
- public List<RowData> searchRowData(float[] queryVector, int k) throws IOException {
- List<SearchResult> results = search(queryVector, k);
- List<RowData> rowDataList = new ArrayList<>(results.size());
-
- for (SearchResult result : results) {
- // Append distance score to RowData
- GenericRowData rowWithDistance = new GenericRowData(rowType.getFieldCount() + 1);
- RowData originalRow = result.getRowData();
-
- for (int i = 0; i < rowType.getFieldCount(); i++) {
- rowWithDistance.setField(i, getFieldValue(originalRow, i));
- }
- rowWithDistance.setField(rowType.getFieldCount(), result.getDistance());
-
- rowDataList.add(rowWithDistance);
- }
-
- return rowDataList;
- }
+ List<SearchResult> results = new ArrayList<>();
- /**
- * Get field value from RowData
- */
- private Object getFieldValue(RowData rowData, int index) {
- if (rowData.isNullAt(index)) {
- return null;
- }
-
- // Simplified handling, should get based on field type in practice
- if (rowData instanceof GenericRowData) {
- return ((GenericRowData) rowData).getField(index);
- }
-
- return null;
- }
+ try {
+ // Build vector query
+ Query.Builder queryBuilder =
+ new Query.Builder()
+ .setColumn(columnName)
+ .setKey(queryVector)
+ .setK(k)
+ .setNprobes(nprobes)
+ .setDistanceType(toDistanceType(metricType))
+ .setUseIndex(true);
- /**
- * Validate query vector
- */
- private void validateQueryVector(float[] queryVector) throws IOException {
- if (queryVector == null || queryVector.length == 0) {
- throw new IllegalArgumentException("Query vector cannot be empty");
- }
-
- // Check for NaN or Infinity values
- for (float value : queryVector) {
- if (Float.isNaN(value) || Float.isInfinite(value)) {
- throw new IllegalArgumentException("Query vector contains invalid values (NaN or Infinity)");
- }
- }
- }
+ if (ef > 0) {
+ queryBuilder.setEf(ef);
+ }
- /**
- * Convert distance metric type
- */
- private DistanceType toDistanceType(MetricType metricType) {
- switch (metricType) {
- case L2:
- return DistanceType.L2;
- case COSINE:
- return DistanceType.Cosine;
- case DOT:
- return DistanceType.Dot;
- default:
- return DistanceType.L2;
- }
- }
+ if (refineFactor != null && refineFactor > 0) {
+ queryBuilder.setRefineFactor(refineFactor);
+ }
- @Override
- public void close() throws IOException {
- if (dataset != null) {
+ Query query = queryBuilder.build();
+
+ // Build scan options
+ ScanOptions.Builder scanOptionsBuilder =
+ new ScanOptions.Builder().nearest(query).withRowId(true);
+
+ if (filter != null && !filter.isEmpty()) {
+ scanOptionsBuilder.filter(filter);
+ }
+
+ ScanOptions scanOptions = scanOptionsBuilder.build();
+
+ // Execute search
+ try (LanceScanner scanner = dataset.newScan(scanOptions)) {
+ try (ArrowReader reader = scanner.scanBatches()) {
+ while (reader.loadNextBatch()) {
+ VectorSchemaRoot root = reader.getVectorSchemaRoot();
+
+ // Convert to RowData
+ List<RowData> rows = converter.toRowDataList(root);
+
+ // Try to get distance score (if _distance column exists)
+ Float8Vector distanceVector = null;
try {
- dataset.close();
+ distanceVector = (Float8Vector) root.getVector("_distance");
} catch (Exception e) {
- LOG.warn("Failed to close dataset", e);
+ // _distance column may not exist
}
- dataset = null;
- }
-
- if (allocator != null) {
- try {
- allocator.close();
- } catch (Exception e) {
- LOG.warn("Failed to close allocator", e);
+
+ for (int i = 0; i < rows.size(); i++) {
+ double distance = 0.0;
+ if (distanceVector != null && !distanceVector.isNull(i)) {
+ distance = distanceVector.get(i);
+ }
+ results.add(new SearchResult(rows.get(i), distance));
}
- allocator = null;
+ }
}
+ }
+
+ LOG.debug("Search completed, returned {} results", results.size());
+ return results;
+
+ } catch (Exception e) {
+ throw new IOException("Vector search failed", e);
}
+ }
+
+ /**
+ * Execute vector search (return RowData list)
+ *
+ * @param queryVector Query vector
+ * @param k Number of nearest neighbors to return
+ * @return RowData list
+ */
+ public List<RowData> searchRowData(float[] queryVector, int k) throws IOException {
+ List<SearchResult> results = search(queryVector, k);
+ List<RowData> rowDataList = new ArrayList<>(results.size());
+
+ for (SearchResult result : results) {
+ // Append distance score to RowData
+ GenericRowData rowWithDistance = new GenericRowData(rowType.getFieldCount() + 1);
+ RowData originalRow = result.getRowData();
+
+ for (int i = 0; i < rowType.getFieldCount(); i++) {
+ rowWithDistance.setField(i, getFieldValue(originalRow, i));
+ }
+ rowWithDistance.setField(rowType.getFieldCount(), result.getDistance());
+
+ rowDataList.add(rowWithDistance);
+ }
+
+ return rowDataList;
+ }
- /**
- * Get RowType
- */
- public RowType getRowType() {
- return rowType;
+ /** Get field value from RowData */
+ private Object getFieldValue(RowData rowData, int index) {
+ if (rowData.isNullAt(index)) {
+ return null;
}
- /**
- * Create builder
- */
- public static Builder builder() {
- return new Builder();
+ // Simplified handling, should get based on field type in practice
+ if (rowData instanceof GenericRowData) {
+ return ((GenericRowData) rowData).getField(index);
}
- /**
- * Create vector searcher from LanceOptions
- */
- public static LanceVectorSearch fromOptions(LanceOptions options) {
- return builder()
- .datasetPath(options.getPath())
- .columnName(options.getVectorColumn())
- .metricType(options.getVectorMetric())
- .nprobes(options.getVectorNprobes())
- .ef(options.getVectorEf())
- .refineFactor(options.getVectorRefineFactor())
- .build();
+ return null;
+ }
+
+ /** Validate query vector */
+ private void validateQueryVector(float[] queryVector) throws IOException {
+ if (queryVector == null || queryVector.length == 0) {
+ throw new IllegalArgumentException("Query vector cannot be empty");
}
- /**
- * Builder
- */
- public static class Builder {
- private String datasetPath;
- private String columnName;
- private MetricType metricType = MetricType.L2;
- private int nprobes = 20;
- private int ef = 100;
- private Integer refineFactor;
-
- public Builder datasetPath(String datasetPath) {
- this.datasetPath = datasetPath;
- return this;
- }
+ // Check for NaN or Infinity values
+ for (float value : queryVector) {
+ if (Float.isNaN(value) || Float.isInfinite(value)) {
+ throw new IllegalArgumentException(
+ "Query vector contains invalid values (NaN or Infinity)");
+ }
+ }
+ }
+
+ /** Convert distance metric type */
+ private DistanceType toDistanceType(MetricType metricType) {
+ switch (metricType) {
+ case L2:
+ return DistanceType.L2;
+ case COSINE:
+ return DistanceType.Cosine;
+ case DOT:
+ return DistanceType.Dot;
+ default:
+ return DistanceType.L2;
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ if (dataset != null) {
+ try {
+ dataset.close();
+ } catch (Exception e) {
+ LOG.warn("Failed to close dataset", e);
+ }
+ dataset = null;
+ }
- public Builder columnName(String columnName) {
- this.columnName = columnName;
- return this;
- }
+ if (allocator != null) {
+ try {
+ allocator.close();
+ } catch (Exception e) {
+ LOG.warn("Failed to close allocator", e);
+ }
+ allocator = null;
+ }
+ }
+
+ /** Get RowType */
+ public RowType getRowType() {
+ return rowType;
+ }
+
+ /** Create builder */
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ /** Create vector searcher from LanceOptions */
+ public static LanceVectorSearch fromOptions(LanceOptions options) {
+ return builder()
+ .datasetPath(options.getPath())
+ .columnName(options.getVectorColumn())
+ .metricType(options.getVectorMetric())
+ .nprobes(options.getVectorNprobes())
+ .ef(options.getVectorEf())
+ .refineFactor(options.getVectorRefineFactor())
+ .build();
+ }
+
+ /** Builder */
+ public static class Builder {
+ private String datasetPath;
+ private String columnName;
+ private MetricType metricType = MetricType.L2;
+ private int nprobes = 20;
+ private int ef = 100;
+ private Integer refineFactor;
+
+ public Builder datasetPath(String datasetPath) {
+ this.datasetPath = datasetPath;
+ return this;
+ }
- public Builder metricType(MetricType metricType) {
- this.metricType = metricType;
- return this;
- }
+ public Builder columnName(String columnName) {
+ this.columnName = columnName;
+ return this;
+ }
- public Builder nprobes(int nprobes) {
- this.nprobes = nprobes;
- return this;
- }
+ public Builder metricType(MetricType metricType) {
+ this.metricType = metricType;
+ return this;
+ }
- public Builder ef(int ef) {
- this.ef = ef;
- return this;
- }
+ public Builder nprobes(int nprobes) {
+ this.nprobes = nprobes;
+ return this;
+ }
- public Builder refineFactor(Integer refineFactor) {
- this.refineFactor = refineFactor;
- return this;
- }
+ public Builder ef(int ef) {
+ this.ef = ef;
+ return this;
+ }
- public LanceVectorSearch build() {
- validate();
- return new LanceVectorSearch(this);
- }
+ public Builder refineFactor(Integer refineFactor) {
+ this.refineFactor = refineFactor;
+ return this;
+ }
- private void validate() {
- if (datasetPath == null || datasetPath.isEmpty()) {
- throw new IllegalArgumentException("Dataset path cannot be empty");
- }
- if (columnName == null || columnName.isEmpty()) {
- throw new IllegalArgumentException("Column name cannot be empty");
- }
- if (nprobes <= 0) {
- throw new IllegalArgumentException("nprobes must be greater than 0");
- }
- }
+ public LanceVectorSearch build() {
+ validate();
+ return new LanceVectorSearch(this);
}
- /**
- * Search result
- */
- public static class SearchResult implements Serializable {
- private static final long serialVersionUID = 1L;
+ private void validate() {
+ if (datasetPath == null || datasetPath.isEmpty()) {
+ throw new IllegalArgumentException("Dataset path cannot be empty");
+ }
+ if (columnName == null || columnName.isEmpty()) {
+ throw new IllegalArgumentException("Column name cannot be empty");
+ }
+ if (nprobes <= 0) {
+ throw new IllegalArgumentException("nprobes must be greater than 0");
+ }
+ }
+ }
- private final RowData rowData;
- private final double distance;
+ /** Search result */
+ public static class SearchResult implements Serializable {
+ private static final long serialVersionUID = 1L;
- public SearchResult(RowData rowData, double distance) {
- this.rowData = rowData;
- this.distance = distance;
- }
+ private final RowData rowData;
+ private final double distance;
- public RowData getRowData() {
- return rowData;
- }
+ public SearchResult(RowData rowData, double distance) {
+ this.rowData = rowData;
+ this.distance = distance;
+ }
- public double getDistance() {
- return distance;
- }
+ public RowData getRowData() {
+ return rowData;
+ }
- /**
- * Get similarity score (inverse or negative of distance, depending on distance type)
- */
- public double getSimilarity() {
- if (distance == 0) {
- return 1.0;
- }
- // For L2 distance, use 1 / (1 + distance) as similarity
- return 1.0 / (1.0 + distance);
- }
+ public double getDistance() {
+ return distance;
+ }
- @Override
- public boolean equals(Object o) {
- if (this == o) return true;
- if (o == null || getClass() != o.getClass()) return false;
- SearchResult that = (SearchResult) o;
- return Double.compare(that.distance, distance) == 0 &&
- Objects.equals(rowData, that.rowData);
- }
+ /** Get similarity score (inverse or negative of distance, depending on distance type) */
+ public double getSimilarity() {
+ if (distance == 0) {
+ return 1.0;
+ }
+ // For L2 distance, use 1 / (1 + distance) as similarity
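+ // e.g. distance 0 -> 1.0, distance 1 -> 0.5, distance 3 -> 0.25;
+ // monotonically decreasing and bounded in (0, 1].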
+ return 1.0 / (1.0 + distance);
+ }
- @Override
- public int hashCode() {
- return Objects.hash(rowData, distance);
- }
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ SearchResult that = (SearchResult) o;
+ return Double.compare(that.distance, distance) == 0 && Objects.equals(rowData, that.rowData);
+ }
- @Override
- public String toString() {
- return "SearchResult{" +
- "rowData=" + rowData +
- ", distance=" + distance +
- '}';
- }
+ @Override
+ public int hashCode() {
+ return Objects.hash(rowData, distance);
+ }
+
+ @Override
+ public String toString() {
+ return "SearchResult{" + "rowData=" + rowData + ", distance=" + distance + '}';
}
+ }
}
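
Reviewer note: a search sketch against the API above. The dataset path, vector column, and query values are illustrative; COSINE is one of the three metrics mapped by toDistanceType, and search() opens the dataset lazily on first use.

    import java.util.List;

    float[] query = {0.1f, 0.2f, 0.3f}; // must match the indexed vector dimension

    try (LanceVectorSearch search =
            LanceVectorSearch.builder()
                    .datasetPath("/tmp/lance/embeddings.lance") // hypothetical path
                    .columnName("embedding")
                    .metricType(MetricType.COSINE)
                    .nprobes(20)
                    .build()) {
        List<LanceVectorSearch.SearchResult> results =
                search.search(query, 10, "category = 'news'"); // optional SQL-style filter
        for (LanceVectorSearch.SearchResult r : results) {
            System.out.printf(
                    "distance=%.4f similarity=%.4f%n", r.getDistance(), r.getSimilarity());
        }
    }
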
diff --git a/src/main/java/org/apache/flink/connector/lance/aggregate/AggregateExecutor.java b/src/main/java/org/apache/flink/connector/lance/aggregate/AggregateExecutor.java
index 5509e14..d51bc5c 100644
--- a/src/main/java/org/apache/flink/connector/lance/aggregate/AggregateExecutor.java
+++ b/src/main/java/org/apache/flink/connector/lance/aggregate/AggregateExecutor.java
@@ -1,11 +1,7 @@
/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
@@ -15,7 +11,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.flink.connector.lance.aggregate;
import org.apache.flink.table.data.DecimalData;
@@ -41,518 +36,493 @@
/**
 * Aggregate executor.
- *
- * Executes aggregate calculations at data source side, supports COUNT, SUM, AVG, MIN, MAX and other aggregate functions.
+ *
+ * <p>Executes aggregate calculations at the data source side, supporting COUNT, SUM, AVG, MIN,
+ * MAX and other aggregate functions.
 */
public class AggregateExecutor implements Serializable {
- private static final long serialVersionUID = 1L;
- private static final Logger LOG = LoggerFactory.getLogger(AggregateExecutor.class);
+ private static final long serialVersionUID = 1L;
+ private static final Logger LOG = LoggerFactory.getLogger(AggregateExecutor.class);
- private final AggregateInfo aggregateInfo;
- private final RowType sourceRowType;
+ private final AggregateInfo aggregateInfo;
+ private final RowType sourceRowType;
- // Aggregate state (by group key)
- private transient Map<GroupKey, AggregateState> aggregateStates;
- private transient boolean initialized;
+ // Aggregate state (by group key)
+ private transient Map<GroupKey, AggregateState> aggregateStates;
+ private transient boolean initialized;
- public AggregateExecutor(AggregateInfo aggregateInfo, RowType sourceRowType) {
- this.aggregateInfo = aggregateInfo;
- this.sourceRowType = sourceRowType;
- }
+ public AggregateExecutor(AggregateInfo aggregateInfo, RowType sourceRowType) {
+ this.aggregateInfo = aggregateInfo;
+ this.sourceRowType = sourceRowType;
+ }
+
+ /** Initialize aggregate executor */
+ public void init() {
+ this.aggregateStates = new HashMap<>();
+ this.initialized = true;
+ LOG.info("Initialized aggregate executor: {}", aggregateInfo);
+ }
- /**
- * Initialize aggregate executor
- */
- public void init() {
- this.aggregateStates = new HashMap<>();
- this.initialized = true;
- LOG.info("Initialized aggregate executor: {}", aggregateInfo);
+ /** Accumulate a row to aggregate state */
+ public void accumulate(RowData row) {
+ if (!initialized) {
+ init();
}
- /**
- * Accumulate a row to aggregate state
- */
- public void accumulate(RowData row) {
- if (!initialized) {
- init();
+ // Extract group key
+ GroupKey groupKey = extractGroupKey(row);
+
+ // Get or create aggregate state
+ AggregateState state =
+ aggregateStates.computeIfAbsent(
+ groupKey, k -> new AggregateState(aggregateInfo.getAggregateCalls().size()));
+
+ // Update state for each aggregate function
+ List<AggregateInfo.AggregateCall> calls = aggregateInfo.getAggregateCalls();
+ for (int i = 0; i < calls.size(); i++) {
+ AggregateInfo.AggregateCall call = calls.get(i);
+ accumulateCall(state, i, call, row);
+ }
+ }
+
+ /** Accumulate single aggregate function */
+ private void accumulateCall(
+ AggregateState state, int index, AggregateInfo.AggregateCall call, RowData row) {
+ switch (call.getFunction()) {
+ case COUNT:
+ if (call.isCountStar()) {
+ // COUNT(*)
+ state.incrementCount(index);
+ } else {
+ // COUNT(column) - only count non-NULL values
+ int fieldIndex = getFieldIndex(call.getColumn());
+ if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) {
+ state.incrementCount(index);
+ }
+ }
+ break;
+
+ case COUNT_DISTINCT:
+ if (call.getColumn() != null) {
+ int fieldIndex = getFieldIndex(call.getColumn());
+ if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) {
+ Object value = extractValue(row, fieldIndex);
+ state.addDistinctValue(index, value);
+ }
+ }
+ break;
+
+ case SUM:
+ if (call.getColumn() != null) {
+ int fieldIndex = getFieldIndex(call.getColumn());
+ if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) {
+ Number value = extractNumericValue(row, fieldIndex);
+ if (value != null) {
+ state.addSum(index, value.doubleValue());
+ }
+ }
}
+ break;
- // Extract group key
- GroupKey groupKey = extractGroupKey(row);
-
- // Get or create aggregate state
- AggregateState state = aggregateStates.computeIfAbsent(groupKey,
- k -> new AggregateState(aggregateInfo.getAggregateCalls().size()));
-
- // Update state for each aggregate function
- List<AggregateInfo.AggregateCall> calls = aggregateInfo.getAggregateCalls();
- for (int i = 0; i < calls.size(); i++) {
- AggregateInfo.AggregateCall call = calls.get(i);
- accumulateCall(state, i, call, row);
+ case AVG:
+ if (call.getColumn() != null) {
+ int fieldIndex = getFieldIndex(call.getColumn());
+ if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) {
+ Number value = extractNumericValue(row, fieldIndex);
+ if (value != null) {
+ state.addForAvg(index, value.doubleValue());
+ }
+ }
}
- }
+ break;
- /**
- * Accumulate single aggregate function
- */
- private void accumulateCall(AggregateState state, int index,
- AggregateInfo.AggregateCall call, RowData row) {
- switch (call.getFunction()) {
- case COUNT:
- if (call.isCountStar()) {
- // COUNT(*)
- state.incrementCount(index);
- } else {
- // COUNT(column) - only count non-NULL values
- int fieldIndex = getFieldIndex(call.getColumn());
- if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) {
- state.incrementCount(index);
- }
- }
- break;
-
- case COUNT_DISTINCT:
- if (call.getColumn() != null) {
- int fieldIndex = getFieldIndex(call.getColumn());
- if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) {
- Object value = extractValue(row, fieldIndex);
- state.addDistinctValue(index, value);
- }
- }
- break;
-
- case SUM:
- if (call.getColumn() != null) {
- int fieldIndex = getFieldIndex(call.getColumn());
- if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) {
- Number value = extractNumericValue(row, fieldIndex);
- if (value != null) {
- state.addSum(index, value.doubleValue());
- }
- }
- }
- break;
-
- case AVG:
- if (call.getColumn() != null) {
- int fieldIndex = getFieldIndex(call.getColumn());
- if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) {
- Number value = extractNumericValue(row, fieldIndex);
- if (value != null) {
- state.addForAvg(index, value.doubleValue());
- }
- }
- }
- break;
-
- case MIN:
- if (call.getColumn() != null) {
- int fieldIndex = getFieldIndex(call.getColumn());
- if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) {
- Comparable<?> value = extractComparableValue(row, fieldIndex);
- if (value != null) {
- state.updateMin(index, value);
- }
- }
- }
- break;
-
- case MAX:
- if (call.getColumn() != null) {
- int fieldIndex = getFieldIndex(call.getColumn());
- if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) {
- Comparable<?> value = extractComparableValue(row, fieldIndex);
- if (value != null) {
- state.updateMax(index, value);
- }
- }
- }
- break;
+ case MIN:
+ if (call.getColumn() != null) {
+ int fieldIndex = getFieldIndex(call.getColumn());
+ if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) {
+ Comparable<?> value = extractComparableValue(row, fieldIndex);
+ if (value != null) {
+ state.updateMin(index, value);
+ }
+ }
}
- }
+ break;
- /**
- * Get aggregate results
- */
- public List<RowData> getResults() {
- if (!initialized || aggregateStates.isEmpty()) {
- // If no data, return default aggregate result
- return getDefaultResults();
+ case MAX:
+ if (call.getColumn() != null) {
+ int fieldIndex = getFieldIndex(call.getColumn());
+ if (fieldIndex >= 0 && !row.isNullAt(fieldIndex)) {
+ Comparable<?> value = extractComparableValue(row, fieldIndex);
+ if (value != null) {
+ state.updateMax(index, value);
+ }
+ }
}
+ break;
+ }
+ }
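
The NULL handling above follows SQL semantics: COUNT(*) advances on every row, while COUNT(column), SUM, AVG, MIN and MAX all skip rows where the column is NULL. Continuing the earlier sketch, assuming `executor` is configured with COUNT(*) and COUNT(name) over a hypothetical `ROW<name STRING, id INT>` source type:

```java
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.StringData;

// Two rows, one with a NULL name. COUNT(*) sees both; COUNT(name) sees one,
// because accumulateCall() checks row.isNullAt(fieldIndex) first.
executor.accumulate(GenericRowData.of(StringData.fromString("alice"), 1));
executor.accumulate(GenericRowData.of(null, 2));
// getResults() now reports cnt_star = 2 but cnt_name = 1.
```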
- List<RowData> results = new ArrayList<>();
- List<AggregateInfo.AggregateCall> calls = aggregateInfo.getAggregateCalls();
- List<String> groupByCols = aggregateInfo.getGroupByColumns();
+ /** Get aggregate results */
+ public List<RowData> getResults() {
+ if (!initialized || aggregateStates.isEmpty()) {
+ // If no data, return default aggregate result
+ return getDefaultResults();
+ }
- for (Map.Entry<GroupKey, AggregateState> entry : aggregateStates.entrySet()) {
- GroupKey groupKey = entry.getKey();
- AggregateState state = entry.getValue();
+ List<RowData> results = new ArrayList<>();
+ List<AggregateInfo.AggregateCall> calls = aggregateInfo.getAggregateCalls();
+ List<String> groupByCols = aggregateInfo.getGroupByColumns();
- // Create result row: group columns + aggregate columns
- int totalFields = groupByCols.size() + calls.size();
- GenericRowData resultRow = new GenericRowData(totalFields);
+ for (Map.Entry<GroupKey, AggregateState> entry : aggregateStates.entrySet()) {
+ GroupKey groupKey = entry.getKey();
+ AggregateState state = entry.getValue();
- // Fill group columns
- for (int i = 0; i < groupByCols.size(); i++) {
- resultRow.setField(i, groupKey.getValues()[i]);
- }
+ // Create result row: group columns + aggregate columns
+ int totalFields = groupByCols.size() + calls.size();
+ GenericRowData resultRow = new GenericRowData(totalFields);
- // Fill aggregate results
- for (int i = 0; i < calls.size(); i++) {
- AggregateInfo.AggregateCall call = calls.get(i);
- Object aggResult = getAggregateResult(state, i, call);
- resultRow.setField(groupByCols.size() + i, aggResult);
- }
+ // Fill group columns
+ for (int i = 0; i < groupByCols.size(); i++) {
+ resultRow.setField(i, groupKey.getValues()[i]);
+ }
- results.add(resultRow);
- }
+ // Fill aggregate results
+ for (int i = 0; i < calls.size(); i++) {
+ AggregateInfo.AggregateCall call = calls.get(i);
+ Object aggResult = getAggregateResult(state, i, call);
+ resultRow.setField(groupByCols.size() + i, aggResult);
+ }
- LOG.info("Aggregate execution completed, generated {} result rows", results.size());
- return results;
+ results.add(resultRow);
}
- /**
- * Get default aggregate result when no data
- */
- private List<RowData> getDefaultResults() {
- // If has GROUP BY, no data means no result
- if (aggregateInfo.hasGroupBy()) {
- return new ArrayList<>();
- }
-
- // No GROUP BY, return default aggregate values
- List<AggregateInfo.AggregateCall> calls = aggregateInfo.getAggregateCalls();
- GenericRowData resultRow = new GenericRowData(calls.size());
-
- for (int i = 0; i < calls.size(); i++) {
- AggregateInfo.AggregateCall call = calls.get(i);
- switch (call.getFunction()) {
- case COUNT:
- case COUNT_DISTINCT:
- resultRow.setField(i, 0L);
- break;
- default:
- resultRow.setField(i, null);
- break;
- }
- }
+ LOG.info("Aggregate execution completed, generated {} result rows", results.size());
+ return results;
+ }
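
The layout of each result row is fixed: group-by columns occupy the leading positions, followed by one field per aggregate call, in declaration order. For example, a hypothetical `SELECT category, COUNT(*), SUM(price) ... GROUP BY category` would be read back like this (fragment continuing the sketch above):

```java
// field 0: category (group key), field 1: COUNT(*) -> Long, field 2: SUM(price) -> Double
for (RowData row : executor.getResults()) {
    StringData category = row.getString(0); // group column
    long count = row.getLong(1);            // aggregate call 0
    double sum = row.getDouble(2);          // aggregate call 1
}
```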
- List<RowData> results = new ArrayList<>();
- results.add(resultRow);
- return results;
- }
-
- /**
- * Get single aggregate function result
- */
- private Object getAggregateResult(AggregateState state, int index,
- AggregateInfo.AggregateCall call) {
- switch (call.getFunction()) {
- case COUNT:
- return state.getCount(index);
- case COUNT_DISTINCT:
- return (long) state.getDistinctCount(index);
- case SUM:
- Double sum = state.getSum(index);
- return sum != null ? sum : null;
- case AVG:
- Double avg = state.getAvg(index);
- return avg != null ? avg : null;
- case MIN:
- return state.getMin(index);
- case MAX:
- return state.getMax(index);
- default:
- return null;
- }
+ /** Get default aggregate result when no data */
+ private List<RowData> getDefaultResults() {
+ // If has GROUP BY, no data means no result
+ if (aggregateInfo.hasGroupBy()) {
+ return new ArrayList<>();
}
- /**
- * Extract group key
- */
- private GroupKey extractGroupKey(RowData row) {
- List<String> groupByCols = aggregateInfo.getGroupByColumns();
- if (groupByCols.isEmpty()) {
- return GroupKey.EMPTY;
- }
-
- Object[] keyValues = new Object[groupByCols.size()];
- for (int i = 0; i < groupByCols.size(); i++) {
- int fieldIndex = getFieldIndex(groupByCols.get(i));
- if (fieldIndex >= 0) {
- keyValues[i] = extractValue(row, fieldIndex);
- }
- }
- return new GroupKey(keyValues);
+ // No GROUP BY, return default aggregate values
+ List<AggregateInfo.AggregateCall> calls = aggregateInfo.getAggregateCalls();
+ GenericRowData resultRow = new GenericRowData(calls.size());
+
+ for (int i = 0; i < calls.size(); i++) {
+ AggregateInfo.AggregateCall call = calls.get(i);
+ switch (call.getFunction()) {
+ case COUNT:
+ case COUNT_DISTINCT:
+ resultRow.setField(i, 0L);
+ break;
+ default:
+ resultRow.setField(i, null);
+ break;
+ }
}
- /**
- * Get field index
- */
- private int getFieldIndex(String columnName) {
- List<String> fieldNames = sourceRowType.getFieldNames();
- return fieldNames.indexOf(columnName);
+ List<RowData> results = new ArrayList<>();
+ results.add(resultRow);
+ return results;
+ }
+
+ /** Get single aggregate function result */
+ private Object getAggregateResult(
+ AggregateState state, int index, AggregateInfo.AggregateCall call) {
+ switch (call.getFunction()) {
+ case COUNT:
+ return state.getCount(index);
+ case COUNT_DISTINCT:
+ return (long) state.getDistinctCount(index);
+ case SUM:
+ Double sum = state.getSum(index);
+ return sum != null ? sum : null;
+ case AVG:
+ Double avg = state.getAvg(index);
+ return avg != null ? avg : null;
+ case MIN:
+ return state.getMin(index);
+ case MAX:
+ return state.getMax(index);
+ default:
+ return null;
}
+ }
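
Result types follow from the state representation: COUNT and COUNT DISTINCT come back as Long, SUM and AVG as Double (sums are accumulated in double, trading exactness for simplicity on DECIMAL and large BIGINT inputs), and MIN/MAX return whatever extractValue produced. A worked AVG example, assuming a single AVG call over an INT column:

```java
// AVG is derived as sum / count inside AggregateState: 1 + 2 + 4 = 7.0 over 3 rows.
executor.accumulate(GenericRowData.of(1));
executor.accumulate(GenericRowData.of(2));
executor.accumulate(GenericRowData.of(4));
double avg = executor.getResults().get(0).getDouble(0); // 7.0 / 3 = 2.3333333333333335
```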
- /**
- * Extract field value
- */
- private Object extractValue(RowData row, int fieldIndex) {
- if (row.isNullAt(fieldIndex)) {
- return null;
- }
+ /** Extract group key */
+ private GroupKey extractGroupKey(RowData row) {
+ List<String> groupByCols = aggregateInfo.getGroupByColumns();
+ if (groupByCols.isEmpty()) {
+ return GroupKey.EMPTY;
+ }
- LogicalType fieldType = sourceRowType.getTypeAt(fieldIndex);
- switch (fieldType.getTypeRoot()) {
- case BOOLEAN:
- return row.getBoolean(fieldIndex);
- case TINYINT:
- return row.getByte(fieldIndex);
- case SMALLINT:
- return row.getShort(fieldIndex);
- case INTEGER:
- return row.getInt(fieldIndex);
- case BIGINT:
- return row.getLong(fieldIndex);
- case FLOAT:
- return row.getFloat(fieldIndex);
- case DOUBLE:
- return row.getDouble(fieldIndex);
- case CHAR:
- case VARCHAR:
- // Keep StringData type for group key and result output
- return row.getString(fieldIndex);
- case DECIMAL:
- DecimalType decType = (DecimalType) fieldType;
- DecimalData decData = row.getDecimal(fieldIndex, decType.getPrecision(), decType.getScale());
- return decData != null ? decData.toBigDecimal() : null;
- default:
- return null;
- }
+ Object[] keyValues = new Object[groupByCols.size()];
+ for (int i = 0; i < groupByCols.size(); i++) {
+ int fieldIndex = getFieldIndex(groupByCols.get(i));
+ if (fieldIndex >= 0) {
+ keyValues[i] = extractValue(row, fieldIndex);
+ }
+ }
+ return new GroupKey(keyValues);
+ }
+
+ /** Get field index */
+ private int getFieldIndex(String columnName) {
+ List<String> fieldNames = sourceRowType.getFieldNames();
+ return fieldNames.indexOf(columnName);
+ }
+
+ /** Extract field value */
+ private Object extractValue(RowData row, int fieldIndex) {
+ if (row.isNullAt(fieldIndex)) {
+ return null;
}
- /**
- * Extract numeric type field value
- */
- private Number extractNumericValue(RowData row, int fieldIndex) {
- if (row.isNullAt(fieldIndex)) {
- return null;
- }
+ LogicalType fieldType = sourceRowType.getTypeAt(fieldIndex);
+ switch (fieldType.getTypeRoot()) {
+ case BOOLEAN:
+ return row.getBoolean(fieldIndex);
+ case TINYINT:
+ return row.getByte(fieldIndex);
+ case SMALLINT:
+ return row.getShort(fieldIndex);
+ case INTEGER:
+ return row.getInt(fieldIndex);
+ case BIGINT:
+ return row.getLong(fieldIndex);
+ case FLOAT:
+ return row.getFloat(fieldIndex);
+ case DOUBLE:
+ return row.getDouble(fieldIndex);
+ case CHAR:
+ case VARCHAR:
+ // Keep StringData type for group key and result output
+ return row.getString(fieldIndex);
+ case DECIMAL:
+ DecimalType decType = (DecimalType) fieldType;
+ DecimalData decData =
+ row.getDecimal(fieldIndex, decType.getPrecision(), decType.getScale());
+ return decData != null ? decData.toBigDecimal() : null;
+ default:
+ return null;
+ }
+ }
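
Keeping `StringData` for CHAR/VARCHAR (rather than converting to `java.lang.String`) lets string values serve both as group-key components and as output fields without a round trip. This is safe because `StringData` defines content-based `equals`/`hashCode` and extends `Comparable<StringData>`, so it also works for MIN/MAX. A quick illustration:

```java
import org.apache.flink.table.data.StringData;

StringData a = StringData.fromString("books");
StringData b = StringData.fromString("books");
// Equal contents hash and compare equal, so both rows land in the same group.
assert a.equals(b) && a.hashCode() == b.hashCode();
assert a.compareTo(b) == 0; // usable by updateMin()/updateMax() via Comparable
```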
- LogicalType fieldType = sourceRowType.getTypeAt(fieldIndex);
- switch (fieldType.getTypeRoot()) {
- case TINYINT:
- return row.getByte(fieldIndex);
- case SMALLINT:
- return row.getShort(fieldIndex);
- case INTEGER:
- return row.getInt(fieldIndex);
- case BIGINT:
- return row.getLong(fieldIndex);
- case FLOAT:
- return row.getFloat(fieldIndex);
- case DOUBLE:
- return row.getDouble(fieldIndex);
- case DECIMAL:
- DecimalType decType = (DecimalType) fieldType;
- DecimalData decData = row.getDecimal(fieldIndex, decType.getPrecision(), decType.getScale());
- return decData != null ? decData.toBigDecimal() : null;
- default:
- return null;
- }
+ /** Extract numeric type field value */
+ private Number extractNumericValue(RowData row, int fieldIndex) {
+ if (row.isNullAt(fieldIndex)) {
+ return null;
}
- /**
- * Extract comparable type field value
- */
- @SuppressWarnings("unchecked")
- private Comparable<?> extractComparableValue(RowData row, int fieldIndex) {
- Object value = extractValue(row, fieldIndex);
- if (value instanceof Comparable) {
- return (Comparable<?>) value;
- }
+ LogicalType fieldType = sourceRowType.getTypeAt(fieldIndex);
+ switch (fieldType.getTypeRoot()) {
+ case TINYINT:
+ return row.getByte(fieldIndex);
+ case SMALLINT:
+ return row.getShort(fieldIndex);
+ case INTEGER:
+ return row.getInt(fieldIndex);
+ case BIGINT:
+ return row.getLong(fieldIndex);
+ case FLOAT:
+ return row.getFloat(fieldIndex);
+ case DOUBLE:
+ return row.getDouble(fieldIndex);
+ case DECIMAL:
+ DecimalType decType = (DecimalType) fieldType;
+ DecimalData decData =
+ row.getDecimal(fieldIndex, decType.getPrecision(), decType.getScale());
+ return decData != null ? decData.toBigDecimal() : null;
+ default:
return null;
}
+ }
+
+ /** Extract comparable type field value */
+ @SuppressWarnings("unchecked")
+ private Comparable<?> extractComparableValue(RowData row, int fieldIndex) {
+ Object value = extractValue(row, fieldIndex);
+ if (value instanceof Comparable) {
+ return (Comparable<?>) value;
+ }
+ return null;
+ }
- /**
- * Reset aggregate state
- */
- public void reset() {
- if (aggregateStates != null) {
- aggregateStates.clear();
- }
+ /** Reset aggregate state */
+ public void reset() {
+ if (aggregateStates != null) {
+ aggregateStates.clear();
}
+ }
- /**
- * Group key
- */
- private static class GroupKey implements Serializable {
- private static final long serialVersionUID = 1L;
-
- static final GroupKey EMPTY = new GroupKey(new Object[0]);
+ /** Group key */
+ private static class GroupKey implements Serializable {
+ private static final long serialVersionUID = 1L;
- private final Object[] values;
- private final int hashCode;
+ static final GroupKey EMPTY = new GroupKey(new Object[0]);
- GroupKey(Object[] values) {
- this.values = values;
- this.hashCode = Objects.hash((Object[]) values);
- }
+ private final Object[] values;
+ private final int hashCode;
- Object[] getValues() {
- return values;
- }
+ GroupKey(Object[] values) {
+ this.values = values;
+ this.hashCode = Objects.hash((Object[]) values);
+ }
- @Override
- public boolean equals(Object o) {
- if (this == o) return true;
- if (o == null || getClass() != o.getClass()) return false;
- GroupKey groupKey = (GroupKey) o;
- return java.util.Arrays.equals(values, groupKey.values);
- }
+ Object[] getValues() {
+ return values;
+ }
- @Override
- public int hashCode() {
- return hashCode;
- }
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ GroupKey groupKey = (GroupKey) o;
+ return java.util.Arrays.equals(values, groupKey.values);
}
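
GroupKey precomputes its hash once at construction (`Objects.hash` over the value array is equivalent to `Arrays.hashCode`), which keeps the `computeIfAbsent` lookup in `accumulate()` cheap; equality falls back to `Arrays.equals`. A sketch of the contract (illustrative only, since GroupKey is a private nested class; the key values are hypothetical):

```java
// Same value arrays => same bucket; the cached hash is consistent with equals.
GroupKey k1 = new GroupKey(new Object[] {StringData.fromString("books"), 2024});
GroupKey k2 = new GroupKey(new Object[] {StringData.fromString("books"), 2024});
assert k1.equals(k2) && k1.hashCode() == k2.hashCode();
```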
- /**
- * Aggregate state
- */
- private static class AggregateState implements Serializable {
- private static final long serialVersionUID = 1L;
-
- private final long[] counts;
- private final double[] sums;
- private final long[] avgCounts; // For calculating AVG
- private final Comparable<?>[] mins;
- private final Comparable<?>[] maxs;
- private final Set