diff --git a/docs/snowflake_arrow_handling.md b/docs/snowflake_arrow_handling.md new file mode 100644 index 0000000000..05362d29d8 --- /dev/null +++ b/docs/snowflake_arrow_handling.md @@ -0,0 +1,46 @@ +# Snowflake Arrow Handling + +This document lists the Snowflake-specific Arrow metadata cases currently handled by Ladybug. + +## Supported Cases + +| Snowflake signal | Example metadata | Arrow physical storage | Ladybug result | Notes | +| --- | --- | --- | --- | --- | +| Raw Snowflake decimal type via `DATA_TYPE` | `DATA_TYPE=NUMBER(12,4)` | Any Arrow storage type | `DECIMAL(12,4)` | Parsed from Snowflake table-schema metadata. | +| Raw Snowflake decimal type via `DATA_TYPE` with implicit scale | `DATA_TYPE=NUMBER(18)` | Any Arrow storage type | `DECIMAL(18,0)` | Missing scale defaults to `0`. | +| Raw Snowflake decimal aliases via `DATA_TYPE` | `DATA_TYPE=NUMERIC(10,3)` or `DATA_TYPE=DECIMAL(10,3)` | Any Arrow storage type | `DECIMAL(10,3)` | Matching is case-insensitive and whitespace-tolerant. | +| Snowflake logical decimal metadata | `logicalType=FIXED`, `precision=7`, `scale=2` | Integer-backed Arrow (`INT8/16/32/64`, `UINT8/16/32/64`) | `DECIMAL(7,2)` | Used for query-result Arrow schemas. | +| Snowflake logical decimal metadata | `logicalType=FIXED`, `precision=9`, `scale=2` | Float-backed Arrow (`FLOAT`, `DOUBLE`) | `DECIMAL(9,2)` | Values are cast into decimal backing storage during scan. | +| Snowflake raw type fallback to logical metadata | malformed `DATA_TYPE` plus valid `logicalType=FIXED` metadata | Integer-backed or float-backed Arrow | `DECIMAL(p,s)` from `logicalType` metadata | If raw `DATA_TYPE` parsing fails, Snowflake `logicalType` parsing is tried next. | +| Snowflake metadata precedence over generic metadata | `DATA_TYPE=NUMBER(12,4)` plus generic `logicalType=DECIMAL`, `precision=9`, `scale=3` | Any Arrow storage type | `DECIMAL(12,4)` | Snowflake raw type metadata wins over generic metadata. 
| + +## Current Scope + +Only Snowflake decimal semantics are handled today. + +Specifically: + +- `NUMBER(p,s)` +- `NUMBER(p)` +- `NUMERIC(p,s)` +- `DECIMAL(p,s)` +- `logicalType=FIXED` + +## Not Yet Handled + +The Snowflake ADBC driver documents additional logical types that are not currently interpreted in a Snowflake-specific way here, including: + +- `real` +- `date` +- `time` +- `timestamp_ltz` +- `timestamp_ntz` +- `timestamp_tz` +- `text` +- `binary` +- `variant` +- `object` +- `array` +- `boolean` + +For those, Ladybug currently relies on the standard Arrow physical type; no Snowflake-specific decoding is applied to them yet. diff --git a/extension b/extension index 3a45fc419c..96191fdbbb 160000 --- a/extension +++ b/extension @@ -1 +1 @@ -Subproject commit 3a45fc419c222f8139c8db0f02a52e99b7f72c88 +Subproject commit 96191fdbbb0967501d934fdd77e8943784be78ff diff --git a/src/common/arrow/CMakeLists.txt b/src/common/arrow/CMakeLists.txt index e9659d26a4..664980d214 100644 --- a/src/common/arrow/CMakeLists.txt +++ b/src/common/arrow/CMakeLists.txt @@ -2,6 +2,10 @@ add_library(lbug_common_arrow OBJECT arrow_array_scan.cpp arrow_converter.cpp + arrow_schema_metadata_generic_decoder.cpp + arrow_schema_metadata_snowflake_decoder.cpp + arrow_schema_metadata_utils.cpp + arrow_schema_metadata.cpp arrow_null_mask_tree.cpp arrow_row_batch.cpp arrow_type.cpp) diff --git a/src/common/arrow/arrow_array_scan.cpp b/src/common/arrow/arrow_array_scan.cpp index 98270cb384..d452176123 100644 --- a/src/common/arrow/arrow_array_scan.cpp +++ b/src/common/arrow/arrow_array_scan.cpp @@ -1,9 +1,11 @@ #include "common/arrow/arrow_converter.h" +#include "common/arrow/arrow_schema_metadata.h" #include "common/exception/runtime.h" #include "common/types/int128_t.h" #include "common/types/interval_t.h" #include "common/types/types.h" #include "common/vector/value_vector.h" +#include "function/cast/functions/cast_decimal.h" #include "function/cast/functions/numeric_limits.h" namespace lbug { 
@@ -58,6 +60,74 @@ static void scanArrowArrayFixedSizePrimitiveAndCastTo(const ArrowArray* array, }); } +template +static void scanArrowArrayIntegerBackedDecimal(const ArrowArray* array, ValueVector& outputVector, + ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, uint64_t count) { + switch (outputVector.dataType.getPhysicalType()) { + case PhysicalTypeID::INT16: + return scanArrowArrayFixedSizePrimitiveAndCastTo(array, outputVector, mask, + srcOffset, dstOffset, count); + case PhysicalTypeID::INT32: + return scanArrowArrayFixedSizePrimitiveAndCastTo(array, outputVector, mask, + srcOffset, dstOffset, count); + case PhysicalTypeID::INT64: + return scanArrowArrayFixedSizePrimitiveAndCastTo(array, outputVector, mask, + srcOffset, dstOffset, count); + case PhysicalTypeID::INT128: + return scanArrowArrayFixedSizePrimitiveAndCastTo(array, outputVector, mask, + srcOffset, dstOffset, count); + default: + throw RuntimeException( + "Invalid decimal output type: " + + PhysicalTypeUtils::toString(outputVector.dataType.getPhysicalType())); + } +} + +template +static void castArrowArrayDecimalValue(SRC input, ValueVector& outputVector, uint64_t pos) { + DST output{}; + function::CastToDecimal::operation(input, output, outputVector, outputVector); + outputVector.setValue(pos, output); +} + +template +static void scanArrowArrayDecimalWithCastTo(const ArrowArray* array, ValueVector& outputVector, + ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, uint64_t count) { + auto arrayBuffer = (const SRC*)array->buffers[1]; + + mask->copyToValueVector(&outputVector, dstOffset, count); + + rowIter(outputVector, count, [&](auto i) { + if (!mask->isNull(i)) { + castArrowArrayDecimalValue(arrayBuffer[i + srcOffset], outputVector, + i + dstOffset); + } + }); +} + +template +static void scanArrowArrayDecimalWithCast(const ArrowArray* array, ValueVector& outputVector, + ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, uint64_t count) { + switch 
(outputVector.dataType.getPhysicalType()) { + case PhysicalTypeID::INT16: + return scanArrowArrayDecimalWithCastTo(array, outputVector, mask, srcOffset, + dstOffset, count); + case PhysicalTypeID::INT32: + return scanArrowArrayDecimalWithCastTo(array, outputVector, mask, srcOffset, + dstOffset, count); + case PhysicalTypeID::INT64: + return scanArrowArrayDecimalWithCastTo(array, outputVector, mask, srcOffset, + dstOffset, count); + case PhysicalTypeID::INT128: + return scanArrowArrayDecimalWithCastTo(array, outputVector, mask, srcOffset, + dstOffset, count); + default: + throw RuntimeException( + "Invalid decimal output type: " + + PhysicalTypeUtils::toString(outputVector.dataType.getPhysicalType())); + } +} + template<> void scanArrowArrayFixedSizePrimitive(const ArrowArray* array, ValueVector& outputVector, ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, uint64_t count) { @@ -408,7 +478,55 @@ static void scanArrowArrayRunEndEncoded(const ArrowSchema* schema, const ArrowAr void ArrowConverter::fromArrowArray(const ArrowSchema* schema, const ArrowArray* array, ValueVector& outputVector, ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, uint64_t count) { + return fromArrowArray(schema, array, outputVector, mask, srcOffset, dstOffset, count, nullptr); +} + +void ArrowConverter::fromArrowArray(const ArrowSchema* schema, const ArrowArray* array, + ValueVector& outputVector, ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, + uint64_t count, const std::optional* logicalTypeInfo) { const auto arrowType = schema->format; + std::optional parsedLogicalTypeInfo; + if (logicalTypeInfo == nullptr) { + parsedLogicalTypeInfo = tryGetArrowLogicalTypeInfo(schema); + logicalTypeInfo = &parsedLogicalTypeInfo; + } + if (logicalTypeInfo->has_value() && + (*logicalTypeInfo)->type == ArrowLogicalTypeInfo::Type::DECIMAL) { + switch (arrowType[0]) { + case 'c': + return scanArrowArrayIntegerBackedDecimal(array, outputVector, mask, srcOffset, 
+ dstOffset, count); + case 'C': + return scanArrowArrayIntegerBackedDecimal(array, outputVector, mask, srcOffset, + dstOffset, count); + case 's': + return scanArrowArrayIntegerBackedDecimal(array, outputVector, mask, srcOffset, + dstOffset, count); + case 'S': + return scanArrowArrayIntegerBackedDecimal(array, outputVector, mask, + srcOffset, dstOffset, count); + case 'i': + return scanArrowArrayIntegerBackedDecimal(array, outputVector, mask, srcOffset, + dstOffset, count); + case 'I': + return scanArrowArrayIntegerBackedDecimal(array, outputVector, mask, + srcOffset, dstOffset, count); + case 'l': + return scanArrowArrayIntegerBackedDecimal(array, outputVector, mask, srcOffset, + dstOffset, count); + case 'L': + return scanArrowArrayIntegerBackedDecimal(array, outputVector, mask, + srcOffset, dstOffset, count); + case 'f': + return scanArrowArrayDecimalWithCast(array, outputVector, mask, srcOffset, + dstOffset, count); + case 'g': + return scanArrowArrayDecimalWithCast(array, outputVector, mask, srcOffset, + dstOffset, count); + default: + break; + } + } if (array->dictionary != nullptr) { switch (arrowType[0]) { case 'c': diff --git a/src/common/arrow/arrow_schema_metadata.cpp b/src/common/arrow/arrow_schema_metadata.cpp new file mode 100644 index 0000000000..7819226d56 --- /dev/null +++ b/src/common/arrow/arrow_schema_metadata.cpp @@ -0,0 +1,25 @@ +#include "common/arrow/arrow_schema_metadata.h" + +#include "arrow_schema_metadata_internal.h" + +namespace lbug { +namespace common { + +std::optional tryGetArrowLogicalTypeInfo(const ArrowSchema* schema) { + if (schema == nullptr || schema->format == nullptr || schema->metadata == nullptr) { + return std::nullopt; + } + const auto metadata = readArrowMetadata(schema->metadata); + if (auto snowflakeRawDataTypeInfo = tryParseSnowflakeRawDataTypeInfo(metadata); + snowflakeRawDataTypeInfo.has_value()) { + return snowflakeRawDataTypeInfo; + } + if (auto snowflakeTypeInfo = tryParseSnowflakeLogicalTypeInfo(schema, 
metadata); + snowflakeTypeInfo.has_value()) { + return snowflakeTypeInfo; + } + return tryParseGenericIntegerBackedDecimalMetadata(schema, metadata); +} + +} // namespace common +} // namespace lbug diff --git a/src/common/arrow/arrow_schema_metadata_generic_decoder.cpp b/src/common/arrow/arrow_schema_metadata_generic_decoder.cpp new file mode 100644 index 0000000000..97d585707a --- /dev/null +++ b/src/common/arrow/arrow_schema_metadata_generic_decoder.cpp @@ -0,0 +1,35 @@ +#include "arrow_schema_metadata_internal.h" + +namespace lbug { +namespace common { + +std::optional tryParseGenericIntegerBackedDecimalMetadata( + const ArrowSchema* schema, const ArrowMetadataMap& metadata) { + if (!isIntegralArrowStorageType(schema->format)) { + return std::nullopt; + } + const auto logicalType = getMetadataValue(metadata, "logicaltype"); + if (!logicalType.has_value()) { + return std::nullopt; + } + const auto normalized = toLower(*logicalType); + if (normalized != "decimal" && normalized != "number" && normalized != "numeric") { + return std::nullopt; + } + const auto precision = getMetadataValue(metadata, "precision"); + const auto scale = getMetadataValue(metadata, "scale"); + if (!precision.has_value() || !scale.has_value()) { + return std::nullopt; + } + const auto parsedPrecision = tryParseUint32(*precision); + const auto parsedScale = tryParseUint32(*scale); + if (!parsedPrecision.has_value() || !parsedScale.has_value() || + !isValidDecimalParameters(*parsedPrecision, *parsedScale)) { + return std::nullopt; + } + return ArrowLogicalTypeInfo{ArrowLogicalTypeInfo::Source::GENERIC_METADATA, + ArrowLogicalTypeInfo::Type::DECIMAL, ArrowDecimalTypeInfo{*parsedPrecision, *parsedScale}}; +} + +} // namespace common +} // namespace lbug diff --git a/src/common/arrow/arrow_schema_metadata_internal.h b/src/common/arrow/arrow_schema_metadata_internal.h new file mode 100644 index 0000000000..3825fc4c6e --- /dev/null +++ b/src/common/arrow/arrow_schema_metadata_internal.h @@ -0,0 
+1,34 @@ +#pragma once + +#include +#include +#include +#include + +#include "common/arrow/arrow_schema_metadata.h" + +namespace lbug { +namespace common { + +using ArrowMetadataMap = std::map; + +bool isIntegralArrowStorageType(const char* arrowType); +bool isFloatingArrowStorageType(const char* arrowType); +std::string toLower(std::string value); +ArrowMetadataMap readArrowMetadata(const char* metadata); +std::string trim(std::string value); +std::vector splitCommaSeparated(std::string value); +std::optional getMetadataValue(const ArrowMetadataMap& metadata, + const std::string& key); +std::optional tryParseUint32(const std::string& value); +bool isValidDecimalParameters(uint32_t precision, uint32_t scale); + +std::optional tryParseSnowflakeRawDataTypeInfo( + const ArrowMetadataMap& metadata); +std::optional tryParseSnowflakeLogicalTypeInfo(const ArrowSchema* schema, + const ArrowMetadataMap& metadata); +std::optional tryParseGenericIntegerBackedDecimalMetadata( + const ArrowSchema* schema, const ArrowMetadataMap& metadata); + +} // namespace common +} // namespace lbug diff --git a/src/common/arrow/arrow_schema_metadata_snowflake_decoder.cpp b/src/common/arrow/arrow_schema_metadata_snowflake_decoder.cpp new file mode 100644 index 0000000000..d4253b88d6 --- /dev/null +++ b/src/common/arrow/arrow_schema_metadata_snowflake_decoder.cpp @@ -0,0 +1,86 @@ +#include "arrow_schema_metadata_internal.h" + +namespace lbug { +namespace common { + +namespace { + +std::optional tryParseSnowflakeDecimalType(const std::string& rawDataType) { + auto normalized = toLower(trim(rawDataType)); + const auto openParen = normalized.find('('); + if (openParen == std::string::npos) { + return std::nullopt; + } + const auto typeName = trim(normalized.substr(0, openParen)); + if (typeName != "number" && typeName != "numeric" && typeName != "decimal") { + return std::nullopt; + } + const auto closeParen = normalized.find(')', openParen + 1); + if (closeParen == std::string::npos) { + return 
std::nullopt; + } + auto args = splitCommaSeparated(normalized.substr(openParen + 1, closeParen - openParen - 1)); + if (args.empty() || args.size() > 2) { + return std::nullopt; + } + const auto precision = tryParseUint32(args[0]); + if (!precision.has_value()) { + return std::nullopt; + } + auto scale = std::optional{0}; + if (args.size() == 2) { + scale = tryParseUint32(args[1]); + } + if (!scale.has_value() || !isValidDecimalParameters(*precision, *scale)) { + return std::nullopt; + } + return ArrowDecimalTypeInfo{*precision, *scale}; +} + +} // namespace + +std::optional tryParseSnowflakeRawDataTypeInfo( + const ArrowMetadataMap& metadata) { + const auto rawDataType = getMetadataValue(metadata, "data_type"); + if (!rawDataType.has_value()) { + return std::nullopt; + } + const auto decimalInfo = tryParseSnowflakeDecimalType(*rawDataType); + if (!decimalInfo.has_value()) { + return std::nullopt; + } + return ArrowLogicalTypeInfo{ArrowLogicalTypeInfo::Source::SNOWFLAKE, + ArrowLogicalTypeInfo::Type::DECIMAL, *decimalInfo}; +} + +std::optional tryParseSnowflakeLogicalTypeInfo(const ArrowSchema* schema, + const ArrowMetadataMap& metadata) { + if (!isIntegralArrowStorageType(schema->format) && + !isFloatingArrowStorageType(schema->format)) { + return std::nullopt; + } + const auto logicalType = getMetadataValue(metadata, "logicaltype"); + if (!logicalType.has_value()) { + return std::nullopt; + } + const auto normalized = toLower(*logicalType); + if (normalized != "fixed") { + return std::nullopt; + } + const auto precision = getMetadataValue(metadata, "precision"); + const auto scale = getMetadataValue(metadata, "scale"); + if (!precision.has_value() || !scale.has_value()) { + return std::nullopt; + } + const auto parsedPrecision = tryParseUint32(*precision); + const auto parsedScale = tryParseUint32(*scale); + if (!parsedPrecision.has_value() || !parsedScale.has_value() || + !isValidDecimalParameters(*parsedPrecision, *parsedScale)) { + return std::nullopt; + } + 
return ArrowLogicalTypeInfo{ArrowLogicalTypeInfo::Source::SNOWFLAKE, + ArrowLogicalTypeInfo::Type::DECIMAL, ArrowDecimalTypeInfo{*parsedPrecision, *parsedScale}}; +} + +} // namespace common +} // namespace lbug diff --git a/src/common/arrow/arrow_schema_metadata_utils.cpp b/src/common/arrow/arrow_schema_metadata_utils.cpp new file mode 100644 index 0000000000..e8a9a23e15 --- /dev/null +++ b/src/common/arrow/arrow_schema_metadata_utils.cpp @@ -0,0 +1,145 @@ +#include +#include +#include +#include + +#include "arrow_schema_metadata_internal.h" +#include "common/constants.h" + +namespace lbug { +namespace common { + +bool isIntegralArrowStorageType(const char* arrowType) { + switch (arrowType[0]) { + case 'c': + case 'C': + case 's': + case 'S': + case 'i': + case 'I': + case 'l': + case 'L': + return true; + default: + return false; + } +} + +bool isFloatingArrowStorageType(const char* arrowType) { + return arrowType[0] == 'f' || arrowType[0] == 'g'; +} + +std::string toLower(std::string value) { + std::transform(value.begin(), value.end(), value.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + return value; +} + +ArrowMetadataMap readArrowMetadata(const char* metadata) { + ArrowMetadataMap result; + if (metadata == nullptr) { + return result; + } + const auto* ptr = reinterpret_cast(metadata); + size_t bytesRead = 0; + auto tryReadInt32 = [&](int32_t& out) { + if (std::numeric_limits::max() - bytesRead < sizeof(int32_t)) { + return false; + } + memcpy(&out, ptr, sizeof(int32_t)); + ptr += sizeof(int32_t); + bytesRead += sizeof(int32_t); + return true; + }; + auto tryReadString = [&](std::string& out) { + int32_t len = 0; + if (!tryReadInt32(len) || len < 0) { + return false; + } + const auto lenSize = static_cast(len); + if (std::numeric_limits::max() - bytesRead < lenSize) { + return false; + } + out.assign(reinterpret_cast(ptr), lenSize); + ptr += lenSize; + bytesRead += lenSize; + return true; + }; + + int32_t numEntries = 0; + if 
(!tryReadInt32(numEntries) || numEntries < 0) { + return {}; + } + for (auto i = 0; i < numEntries; ++i) { + std::string key; + std::string value; + // ArrowSchema.metadata does not expose an outer byte length, so we can validate shape + // (negative lengths, arithmetic overflow) but not fully bounds-check truncated buffers. + if (!tryReadString(key) || !tryReadString(value)) { + return {}; + } + + result.emplace(toLower(std::move(key)), std::move(value)); + } + return result; +} + +std::string trim(std::string value) { + value.erase(value.begin(), + std::find_if(value.begin(), value.end(), [](unsigned char c) { return !std::isspace(c); })); + value.erase( + std::find_if(value.rbegin(), value.rend(), [](unsigned char c) { return !std::isspace(c); }) + .base(), + value.end()); + return value; +} + +std::vector splitCommaSeparated(std::string value) { + std::vector result; + size_t start = 0; + while (start <= value.size()) { + const auto end = value.find(',', start); + auto part = + end == std::string::npos ? 
value.substr(start) : value.substr(start, end - start); + result.push_back(trim(std::move(part))); + if (end == std::string::npos) { + break; + } + start = end + 1; + } + return result; +} + +std::optional getMetadataValue(const ArrowMetadataMap& metadata, + const std::string& key) { + const auto entry = metadata.find(key); + if (entry == metadata.end()) { + return std::nullopt; + } + return entry->second; +} + +std::optional tryParseUint32(const std::string& value) { + if (value.empty()) { + return std::nullopt; + } + uint32_t parsed = 0; + for (auto c : value) { + if (!std::isdigit(static_cast(c))) { + return std::nullopt; + } + const auto digit = static_cast(c - '0'); + if (parsed > (std::numeric_limits::max() - digit) / 10) { + return std::nullopt; + } + parsed = parsed * 10 + digit; + } + return parsed; +} + +bool isValidDecimalParameters(uint32_t precision, uint32_t scale) { + return precision > 0 && precision <= DECIMAL_PRECISION_LIMIT && scale <= precision; +} + +} // namespace common +} // namespace lbug diff --git a/src/common/arrow/arrow_type.cpp b/src/common/arrow/arrow_type.cpp index a82aa47d4b..aadc1c5c5e 100644 --- a/src/common/arrow/arrow_type.cpp +++ b/src/common/arrow/arrow_type.cpp @@ -1,4 +1,5 @@ #include "common/arrow/arrow_converter.h" +#include "common/arrow/arrow_schema_metadata.h" #include "common/exception/not_implemented.h" #include "common/string_utils.h" @@ -16,6 +17,13 @@ LogicalType ArrowConverter::fromArrowSchema(const ArrowSchema* schema) { if (schema->dictionary != nullptr) { return fromArrowSchema(schema->dictionary); } + if (auto logicalTypeInfo = tryGetArrowLogicalTypeInfo(schema); logicalTypeInfo.has_value()) { + switch (logicalTypeInfo->type) { + case ArrowLogicalTypeInfo::Type::DECIMAL: + return LogicalType::DECIMAL(logicalTypeInfo->decimal.precision, + logicalTypeInfo->decimal.scale); + } + } switch (arrowType[0]) { case 'n': return LogicalType(LogicalTypeID::ANY); diff --git a/src/include/common/arrow/arrow_converter.h 
b/src/include/common/arrow/arrow_converter.h index f35e638a67..f42eafc737 100644 --- a/src/include/common/arrow/arrow_converter.h +++ b/src/include/common/arrow/arrow_converter.h @@ -3,8 +3,10 @@ #include #include +#include "common/api.h" #include "common/arrow/arrow.h" #include "common/arrow/arrow_nullmask_tree.h" +#include "common/arrow/arrow_schema_metadata.h" struct ArrowSchema; @@ -20,7 +22,7 @@ struct ArrowSchemaHolder { std::vector> ownedMetadatas; }; -class ArrowConverter { +class LBUG_API ArrowConverter { public: static std::unique_ptr toArrowSchema(const std::vector& dataTypes, const std::vector& columnNames, bool fallbackExtensionTypes); @@ -29,6 +31,9 @@ class ArrowConverter { static void fromArrowArray(const ArrowSchema* schema, const ArrowArray* array, ValueVector& outputVector, ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, uint64_t count); + static void fromArrowArray(const ArrowSchema* schema, const ArrowArray* array, + ValueVector& outputVector, ArrowNullMaskTree* mask, uint64_t srcOffset, uint64_t dstOffset, + uint64_t count, const std::optional* logicalTypeInfo); static void fromArrowArray(const ArrowSchema* schema, const ArrowArray* array, ValueVector& outputVector); diff --git a/src/include/common/arrow/arrow_schema_metadata.h b/src/include/common/arrow/arrow_schema_metadata.h new file mode 100644 index 0000000000..ff537b9203 --- /dev/null +++ b/src/include/common/arrow/arrow_schema_metadata.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include + +#include "common/api.h" +#include "common/arrow/arrow.h" + +namespace lbug { +namespace common { + +struct ArrowDecimalTypeInfo { + uint32_t precision; + uint32_t scale; +}; + +struct ArrowLogicalTypeInfo { + enum class Source : uint8_t { + SNOWFLAKE, + GENERIC_METADATA, + }; + + enum class Type : uint8_t { + DECIMAL, + }; + + Source source; + Type type; + ArrowDecimalTypeInfo decimal; +}; + +LBUG_API std::optional tryGetArrowLogicalTypeInfo(const ArrowSchema* schema); + +} // 
namespace common +} // namespace lbug diff --git a/src/include/storage/table/arrow_node_table.h b/src/include/storage/table/arrow_node_table.h index febe709207..072632a9b4 100644 --- a/src/include/storage/table/arrow_node_table.h +++ b/src/include/storage/table/arrow_node_table.h @@ -6,6 +6,7 @@ #include "catalog/catalog_entry/node_table_catalog_entry.h" #include "common/arrow/arrow.h" +#include "common/arrow/arrow_schema_metadata.h" #include "common/cast.h" #include "common/exception/runtime.h" #include "function/table/table_function.h" @@ -122,6 +123,7 @@ class ArrowNodeTable final : public ColumnarNodeTableBase { private: ArrowSchemaWrapper schema; std::vector arrays; + std::vector> columnLogicalTypeInfos; std::vector batchStartOffsets; size_t totalRows; std::string arrowId; // ID in registry for cleanup diff --git a/src/include/storage/table/arrow_rel_table.h b/src/include/storage/table/arrow_rel_table.h index 7f6d8c9283..3f49439824 100644 --- a/src/include/storage/table/arrow_rel_table.h +++ b/src/include/storage/table/arrow_rel_table.h @@ -7,6 +7,7 @@ #include "catalog/catalog_entry/rel_group_catalog_entry.h" #include "common/arrow/arrow.h" +#include "common/arrow/arrow_schema_metadata.h" #include "storage/table/arrow_table_support.h" #include "storage/table/columnar_rel_table_base.h" #include "storage/table/node_table.h" @@ -69,6 +70,7 @@ class ArrowRelTable final : public ColumnarRelTableBase { common::offset_t findCSRSourceOffset(common::offset_t relOffset) const; bool readArrowValueAtOffset(const ArrowSchemaWrapper& arrowSchema, const std::vector& arrowArrays, const std::vector& startOffsets, + const std::vector>& logicalTypeInfos, int64_t columnIdx, common::offset_t rowOffset, common::ValueVector& outputVector, uint64_t dstOffset) const; @@ -77,9 +79,11 @@ class ArrowRelTable final : public ColumnarRelTableBase { ArrowRelTableLayout layout; ArrowSchemaWrapper schema; std::vector arrays; + std::vector> columnLogicalTypeInfos; std::vector 
batchStartOffsets; ArrowSchemaWrapper indptrSchema; std::vector indptrArrays; + std::vector> indptrColumnLogicalTypeInfos; std::vector indptrBatchStartOffsets; std::unordered_map propertyColumnToArrowColumnIdx; size_t totalRows = 0; diff --git a/src/storage/table/arrow_node_table.cpp b/src/storage/table/arrow_node_table.cpp index a139e14647..2e00e949b9 100644 --- a/src/storage/table/arrow_node_table.cpp +++ b/src/storage/table/arrow_node_table.cpp @@ -24,6 +24,21 @@ static uint64_t getArrowBatchLength(const ArrowArrayWrapper& array) { return 0; } +static std::vector> resolveColumnLogicalTypeInfos( + const ArrowSchemaWrapper& schema) { + std::vector> result; + if (schema.n_children <= 0 || schema.children == nullptr) { + return result; + } + result.reserve(static_cast(schema.n_children)); + for (int64_t i = 0; i < schema.n_children; ++i) { + result.push_back(schema.children[i] == nullptr ? + std::nullopt : + common::tryGetArrowLogicalTypeInfo(schema.children[i])); + } + return result; +} + ArrowNodeTable::ArrowNodeTable(const StorageManager* storageManager, const catalog::NodeTableCatalogEntry* nodeTableEntry, MemoryManager* memoryManager, ArrowSchemaWrapper schema, std::vector arrays, std::string arrowId) @@ -35,6 +50,7 @@ ArrowNodeTable::ArrowNodeTable(const StorageManager* storageManager, if (!this->schema.format) { throw common::RuntimeException("Arrow schema format cannot be null"); } + columnLogicalTypeInfos = resolveColumnLogicalTypeInfos(this->schema); batchStartOffsets.reserve(this->arrays.size()); for (const auto& array : this->arrays) { batchStartOffsets.push_back(totalRows); @@ -203,10 +219,14 @@ void ArrowNodeTable::copyArrowMorselToOutputVectors(const ArrowArrayWrapper& bat auto& outputVector = *outputVectors[outCol]; auto* childArray = batch.children[arrowColIdx]; auto* childSchema = schema.children[arrowColIdx]; + const auto* logicalTypeInfo = + static_cast(arrowColIdx) < columnLogicalTypeInfos.size() ? 
+ &columnLogicalTypeInfos[arrowColIdx] : + nullptr; common::ArrowNullMaskTree nullMask(childSchema, childArray, childArray->offset, childArray->length); common::ArrowConverter::fromArrowArray(childSchema, childArray, outputVector, &nullMask, - childArray->offset + currentMorselStartOffset, 0, numRowsToCopy); + childArray->offset + currentMorselStartOffset, 0, numRowsToCopy, logicalTypeInfo); } } @@ -245,11 +265,15 @@ bool ArrowNodeTable::lookupPK([[maybe_unused]] const transaction::Transaction* t } auto* pkChildArray = batch.children[pkArrowColumnIdx]; auto* pkChildSchema = schema.children[pkArrowColumnIdx]; + const auto* logicalTypeInfo = + static_cast(pkArrowColumnIdx) < columnLogicalTypeInfos.size() ? + &columnLogicalTypeInfos[pkArrowColumnIdx] : + nullptr; common::ArrowNullMaskTree nullMask(pkChildSchema, pkChildArray, pkChildArray->offset, pkChildArray->length); for (uint64_t rowIdx = 0; rowIdx < batchLength; ++rowIdx) { common::ArrowConverter::fromArrowArray(pkChildSchema, pkChildArray, *arrowPKVector, - &nullMask, pkChildArray->offset + rowIdx, 0, 1); + &nullMask, pkChildArray->offset + rowIdx, 0, 1, logicalTypeInfo); if (arrowPKVector->isNull(0)) { continue; } diff --git a/src/storage/table/arrow_rel_table.cpp b/src/storage/table/arrow_rel_table.cpp index 327443eaf5..49d063e5f5 100644 --- a/src/storage/table/arrow_rel_table.cpp +++ b/src/storage/table/arrow_rel_table.cpp @@ -38,6 +38,21 @@ static int64_t findColumnIdx(const ArrowSchemaWrapper& schema, const std::string return -1; } +static std::vector> resolveColumnLogicalTypeInfos( + const ArrowSchemaWrapper& schema) { + std::vector> result; + if (schema.n_children <= 0 || schema.children == nullptr) { + return result; + } + result.reserve(static_cast(schema.n_children)); + for (int64_t i = 0; i < schema.n_children; ++i) { + result.push_back(schema.children[i] == nullptr ? 
+ std::nullopt : + tryGetArrowLogicalTypeInfo(schema.children[i])); + } + return result; +} + void ArrowRelTableScanState::setToTable(const transaction::Transaction* transaction, Table* table_, std::vector columnIDs_, std::vector columnPredicateSets_, RelDataDirection direction_) { @@ -72,6 +87,8 @@ ArrowRelTable::ArrowRelTable(catalog::RelGroupCatalogEntry* relGroupEntry, table if (!this->schema.format) { throw RuntimeException("Arrow schema format cannot be null"); } + columnLogicalTypeInfos = resolveColumnLogicalTypeInfos(this->schema); + indptrColumnLogicalTypeInfos = resolveColumnLogicalTypeInfos(this->indptrSchema); if (!this->fromNodeTable || !this->toNodeTable) { throw RuntimeException( "Arrow relationship table requires source and destination node tables"); @@ -209,9 +226,11 @@ void ArrowRelTable::initScanState([[maybe_unused]] transaction::Transaction* tra } static void readSingleArrowValue(const ArrowSchema* schema, const ArrowArray* array, - ValueVector& outputVector, uint64_t srcOffset, uint64_t dstOffset) { + ValueVector& outputVector, uint64_t srcOffset, uint64_t dstOffset, + const std::optional* logicalTypeInfo = nullptr) { ArrowNullMaskTree nullMask(schema, array, array->offset, array->length); - ArrowConverter::fromArrowArray(schema, array, outputVector, &nullMask, srcOffset, dstOffset, 1); + ArrowConverter::fromArrowArray(schema, array, outputVector, &nullMask, srcOffset, dstOffset, 1, + logicalTypeInfo); } bool ArrowRelTable::scanInternal(transaction::Transaction* transaction, TableScanState& scanState) { @@ -262,14 +281,22 @@ bool ArrowRelTable::scanFlat(transaction::Transaction* transaction, TableScanSta auto* dstChildSchema = schema.children[toColumnIdx]; auto srcOffsetToRead = srcChildArray->offset + srcOffsetInBatch; auto dstOffsetToRead = dstChildArray->offset + srcOffsetInBatch; + const auto* srcLogicalTypeInfo = + static_cast(fromColumnIdx) < columnLogicalTypeInfos.size() ? 
+ &columnLogicalTypeInfos[fromColumnIdx] : + nullptr; + const auto* dstLogicalTypeInfo = + static_cast(toColumnIdx) < columnLogicalTypeInfos.size() ? + &columnLogicalTypeInfos[toColumnIdx] : + nullptr; readSingleArrowValue(srcChildSchema, srcChildArray, *relScanState.arrowSrcKeyVector, - srcOffsetToRead, 0); + srcOffsetToRead, 0, srcLogicalTypeInfo); if (relScanState.arrowSrcKeyVector->isNull(0)) { relScanState.arrowCurrentBatchOffset++; continue; } readSingleArrowValue(dstChildSchema, dstChildArray, *relScanState.arrowDstKeyVector, - dstOffsetToRead, 0); + dstOffsetToRead, 0, dstLogicalTypeInfo); if (relScanState.arrowDstKeyVector->isNull(0)) { relScanState.arrowCurrentBatchOffset++; continue; @@ -331,8 +358,12 @@ bool ArrowRelTable::scanFlat(transaction::Transaction* transaction, TableScanSta } auto* childArray = batch.children[arrowColIdx]; auto* childSchema = schema.children[arrowColIdx]; + const auto* logicalTypeInfo = + static_cast(arrowColIdx) < columnLogicalTypeInfos.size() ? + &columnLogicalTypeInfos[arrowColIdx] : + nullptr; readSingleArrowValue(childSchema, childArray, *relScanState.outputVectors[outCol], - childArray->offset + srcOffsetInBatch, outputCount); + childArray->offset + srcOffsetInBatch, outputCount, logicalTypeInfo); } outputCount++; relScanState.arrowCurrentBatchOffset++; @@ -426,7 +457,7 @@ bool ArrowRelTable::scanCSR(TableScanState& scanState) { if (outCol >= outputToArrowColumnIdx.size() || outputToArrowColumnIdx[outCol] < 0) { continue; } - readArrowValueAtOffset(schema, arrays, batchStartOffsets, + readArrowValueAtOffset(schema, arrays, batchStartOffsets, columnLogicalTypeInfos, outputToArrowColumnIdx[outCol], relOffset, *relScanState.outputVectors[outCol], outputCount); } @@ -476,7 +507,7 @@ bool ArrowRelTable::scanCSR(TableScanState& scanState) { if (outCol >= outputToArrowColumnIdx.size() || outputToArrowColumnIdx[outCol] < 0) { continue; } - readArrowValueAtOffset(schema, arrays, batchStartOffsets, + readArrowValueAtOffset(schema, 
arrays, batchStartOffsets, columnLogicalTypeInfos, outputToArrowColumnIdx[outCol], relOffset, *relScanState.outputVectors[outCol], outputCount); } @@ -509,8 +540,8 @@ bool ArrowRelTable::scanCSR(TableScanState& scanState) { bool ArrowRelTable::readCSRValue(ValueVector& outputVector, offset_t relOffset, uint64_t dstOffset) const { - return readArrowValueAtOffset(schema, arrays, batchStartOffsets, csrNbrColumnIdx, relOffset, - outputVector, dstOffset); + return readArrowValueAtOffset(schema, arrays, batchStartOffsets, columnLogicalTypeInfos, + csrNbrColumnIdx, relOffset, outputVector, dstOffset); } bool ArrowRelTable::readIndptr(offset_t srcOffset, offset_t& result) const { @@ -518,7 +549,7 @@ bool ArrowRelTable::readIndptr(offset_t srcOffset, offset_t& result) const { ValueVector valueVector{LogicalType::UINT64(), memoryManager, singleValueState}; valueVector.state->setToFlat(); if (!readArrowValueAtOffset(indptrSchema, indptrArrays, indptrBatchStartOffsets, - csrIndptrColumnIdx, srcOffset, valueVector, 0) || + indptrColumnLogicalTypeInfos, csrIndptrColumnIdx, srcOffset, valueVector, 0) || valueVector.isNull(0)) { return false; } @@ -555,7 +586,8 @@ offset_t ArrowRelTable::findCSRSourceOffset(offset_t relOffset) const { bool ArrowRelTable::readArrowValueAtOffset(const ArrowSchemaWrapper& arrowSchema, const std::vector& arrowArrays, const std::vector& startOffsets, - int64_t columnIdx, offset_t rowOffset, ValueVector& outputVector, uint64_t dstOffset) const { + const std::vector>& logicalTypeInfos, int64_t columnIdx, + offset_t rowOffset, ValueVector& outputVector, uint64_t dstOffset) const { if (columnIdx < 0 || arrowArrays.empty() || startOffsets.size() != arrowArrays.size()) { return false; } @@ -575,8 +607,11 @@ bool ArrowRelTable::readArrowValueAtOffset(const ArrowSchemaWrapper& arrowSchema } auto* childArray = batch.children[columnIdx]; auto* childSchema = arrowSchema.children[columnIdx]; + const auto* logicalTypeInfo = static_cast(columnIdx) < 
logicalTypeInfos.size() ? + &logicalTypeInfos[columnIdx] : + nullptr; readSingleArrowValue(childSchema, childArray, outputVector, childArray->offset + rowInBatch, - dstOffset); + dstOffset, logicalTypeInfo); return true; } return false; diff --git a/test/api/arrow_test.cpp b/test/api/arrow_test.cpp index 4de3f07394..721b1a3790 100644 --- a/test/api/arrow_test.cpp +++ b/test/api/arrow_test.cpp @@ -1,11 +1,60 @@ +#include +#include +#include + #include "api_test/api_test.h" +#include "arrow_test_utils.h" +#include "common/arrow/arrow_converter.h" +#include "common/arrow/arrow_schema_metadata.h" #include "common/exception/runtime.h" +#include "common/types/types.h" +#include "common/vector/value_vector.h" #include "main/query_result/arrow_query_result.h" using namespace lbug::common; using namespace lbug::main; using namespace lbug::testing; +namespace { + +std::vector serializeArrowMetadata( + const std::vector>& entries) { + auto size = sizeof(int32_t); + for (const auto& [key, value] : entries) { + size += sizeof(int32_t) + key.size() + sizeof(int32_t) + value.size(); + } + std::vector bytes(size); + auto* ptr = bytes.data(); + const auto numEntries = static_cast(entries.size()); + memcpy(ptr, &numEntries, sizeof(int32_t)); + ptr += sizeof(int32_t); + for (const auto& [key, value] : entries) { + const auto keySize = static_cast(key.size()); + memcpy(ptr, &keySize, sizeof(int32_t)); + ptr += sizeof(int32_t); + memcpy(ptr, key.data(), key.size()); + ptr += key.size(); + const auto valueSize = static_cast(value.size()); + memcpy(ptr, &valueSize, sizeof(int32_t)); + ptr += sizeof(int32_t); + memcpy(ptr, value.data(), value.size()); + ptr += value.size(); + } + return bytes; +} + +void appendInt32(std::vector& bytes, int32_t value) { + const auto* ptr = reinterpret_cast(&value); + bytes.insert(bytes.end(), ptr, ptr + sizeof(int32_t)); +} + +void appendString(std::vector& bytes, const std::string& value) { + appendInt32(bytes, static_cast(value.size())); + 
bytes.insert(bytes.end(), value.begin(), value.end()); +} + +} // namespace + class ArrowTest : public ApiTest {}; static void releaseCSRArrowArray(ArrowQueryResult::CSRArrowArray& array) { @@ -54,6 +103,331 @@ TEST(ArrowQueryResultTest, exportsCSRMetadataAsZeroCopyArrowArrays) { (std::vector{10, 11, 12})); } +TEST(ArrowConverterTest, bindsIntegerBackedSnowflakeDecimalMetadataAsDecimal) { + ArrowSchema schema{}; + createSchema(&schema, "amount"); + auto metadata = + serializeArrowMetadata({{"logicalType", "FIXED"}, {"precision", "7"}, {"scale", "2"}}); + schema.metadata = metadata.data(); + + const auto type = ArrowConverter::fromArrowSchema(&schema); + const auto logicalTypeInfo = tryGetArrowLogicalTypeInfo(&schema); + + ASSERT_EQ(type.getLogicalTypeID(), LogicalTypeID::DECIMAL); + ASSERT_EQ(DecimalType::getPrecision(type), 7u); + ASSERT_EQ(DecimalType::getScale(type), 2u); + ASSERT_TRUE(logicalTypeInfo.has_value()); + ASSERT_EQ(logicalTypeInfo->source, ArrowLogicalTypeInfo::Source::SNOWFLAKE); + schema.release(&schema); +} + +TEST(ArrowConverterTest, bindsSnowflakeRawDataTypeMetadataAsDecimal) { + ArrowSchema schema{}; + createSchema(&schema, "amount"); + auto metadata = serializeArrowMetadata({{"DATA_TYPE", "NUMBER(12, 4)"}}); + schema.metadata = metadata.data(); + + const auto type = ArrowConverter::fromArrowSchema(&schema); + const auto logicalTypeInfo = tryGetArrowLogicalTypeInfo(&schema); + + ASSERT_EQ(type.getLogicalTypeID(), LogicalTypeID::DECIMAL); + ASSERT_EQ(DecimalType::getPrecision(type), 12u); + ASSERT_EQ(DecimalType::getScale(type), 4u); + ASSERT_TRUE(logicalTypeInfo.has_value()); + ASSERT_EQ(logicalTypeInfo->source, ArrowLogicalTypeInfo::Source::SNOWFLAKE); + schema.release(&schema); +} + +TEST(ArrowConverterTest, bindsSnowflakeRawNumberMetadataWithoutExplicitScaleAsDecimal) { + ArrowSchema schema{}; + createSchema(&schema, "amount"); + auto metadata = serializeArrowMetadata({{"DATA_TYPE", "NUMBER(18)"}}); + schema.metadata = metadata.data(); + + 
const auto type = ArrowConverter::fromArrowSchema(&schema); + const auto logicalTypeInfo = tryGetArrowLogicalTypeInfo(&schema); + + ASSERT_EQ(type.getLogicalTypeID(), LogicalTypeID::DECIMAL); + ASSERT_EQ(DecimalType::getPrecision(type), 18u); + ASSERT_EQ(DecimalType::getScale(type), 0u); + ASSERT_TRUE(logicalTypeInfo.has_value()); + ASSERT_EQ(logicalTypeInfo->source, ArrowLogicalTypeInfo::Source::SNOWFLAKE); + schema.release(&schema); +} + +TEST(ArrowConverterTest, bindsSnowflakeRawNumericMetadataWithWhitespaceAsDecimal) { + ArrowSchema schema{}; + createSchema(&schema, "amount"); + auto metadata = serializeArrowMetadata({{"DATA_TYPE", " numeric ( 10 , 3 ) "}}); + schema.metadata = metadata.data(); + + const auto type = ArrowConverter::fromArrowSchema(&schema); + + ASSERT_EQ(type.getLogicalTypeID(), LogicalTypeID::DECIMAL); + ASSERT_EQ(DecimalType::getPrecision(type), 10u); + ASSERT_EQ(DecimalType::getScale(type), 3u); + schema.release(&schema); +} + +TEST(ArrowConverterTest, bindsGenericIntegerBackedDecimalMetadataAsDecimal) { + ArrowSchema schema{}; + createSchema(&schema, "amount"); + auto metadata = + serializeArrowMetadata({{"logicalType", "DECIMAL"}, {"precision", "9"}, {"scale", "3"}}); + schema.metadata = metadata.data(); + + const auto type = ArrowConverter::fromArrowSchema(&schema); + const auto logicalTypeInfo = tryGetArrowLogicalTypeInfo(&schema); + + ASSERT_EQ(type.getLogicalTypeID(), LogicalTypeID::DECIMAL); + ASSERT_EQ(DecimalType::getPrecision(type), 9u); + ASSERT_EQ(DecimalType::getScale(type), 3u); + ASSERT_TRUE(logicalTypeInfo.has_value()); + ASSERT_EQ(logicalTypeInfo->source, ArrowLogicalTypeInfo::Source::GENERIC_METADATA); + schema.release(&schema); +} + +TEST(ArrowConverterTest, malformedIntegerBackedDecimalMetadataFallsBackToPhysicalType) { + ArrowSchema schema{}; + createSchema(&schema, "amount"); + auto metadata = + serializeArrowMetadata({{"logicalType", "FIXED"}, {"precision", "bad"}, {"scale", "2"}}); + schema.metadata = metadata.data(); 
+ + const auto type = ArrowConverter::fromArrowSchema(&schema); + + ASSERT_EQ(type.getLogicalTypeID(), LogicalTypeID::INT64); + schema.release(&schema); +} + +TEST(ArrowConverterTest, negativeSnowflakeDecimalMetadataFallsBackToPhysicalType) { + ArrowSchema schema{}; + createSchema(&schema, "amount"); + auto metadata = + serializeArrowMetadata({{"logicalType", "FIXED"}, {"precision", "-1"}, {"scale", "2"}}); + schema.metadata = metadata.data(); + + const auto type = ArrowConverter::fromArrowSchema(&schema); + + ASSERT_EQ(type.getLogicalTypeID(), LogicalTypeID::INT64); + schema.release(&schema); +} + +TEST(ArrowConverterTest, overflowSnowflakeDecimalMetadataFallsBackToPhysicalType) { + ArrowSchema schema{}; + createSchema(&schema, "amount"); + auto metadata = serializeArrowMetadata( + {{"logicalType", "FIXED"}, {"precision", "4294967296"}, {"scale", "2"}}); + schema.metadata = metadata.data(); + + const auto type = ArrowConverter::fromArrowSchema(&schema); + + ASSERT_EQ(type.getLogicalTypeID(), LogicalTypeID::INT64); + schema.release(&schema); +} + +TEST(ArrowConverterTest, invalidSnowflakeDecimalScaleFallsBackToPhysicalType) { + ArrowSchema schema{}; + createSchema(&schema, "amount"); + auto metadata = + serializeArrowMetadata({{"logicalType", "FIXED"}, {"precision", "4"}, {"scale", "5"}}); + schema.metadata = metadata.data(); + + const auto type = ArrowConverter::fromArrowSchema(&schema); + + ASSERT_EQ(type.getLogicalTypeID(), LogicalTypeID::INT64); + schema.release(&schema); +} + +TEST(ArrowConverterTest, negativeMetadataEntryCountFallsBackToPhysicalType) { + ArrowSchema schema{}; + createSchema(&schema, "amount"); + std::vector metadata; + appendInt32(metadata, -1); + schema.metadata = metadata.data(); + + const auto type = ArrowConverter::fromArrowSchema(&schema); + + ASSERT_EQ(type.getLogicalTypeID(), LogicalTypeID::INT64); + ASSERT_FALSE(tryGetArrowLogicalTypeInfo(&schema).has_value()); + schema.release(&schema); +} + +TEST(ArrowConverterTest, 
negativeMetadataKeyLengthFallsBackToPhysicalType) { + ArrowSchema schema{}; + createSchema(&schema, "amount"); + std::vector metadata; + appendInt32(metadata, 1); + appendInt32(metadata, -1); + schema.metadata = metadata.data(); + + const auto type = ArrowConverter::fromArrowSchema(&schema); + + ASSERT_EQ(type.getLogicalTypeID(), LogicalTypeID::INT64); + ASSERT_FALSE(tryGetArrowLogicalTypeInfo(&schema).has_value()); + schema.release(&schema); +} + +TEST(ArrowConverterTest, negativeMetadataValueLengthFallsBackToPhysicalType) { + ArrowSchema schema{}; + createSchema(&schema, "amount"); + std::vector metadata; + appendInt32(metadata, 1); + appendString(metadata, "logicalType"); + appendInt32(metadata, -1); + schema.metadata = metadata.data(); + + const auto type = ArrowConverter::fromArrowSchema(&schema); + + ASSERT_EQ(type.getLogicalTypeID(), LogicalTypeID::INT64); + ASSERT_FALSE(tryGetArrowLogicalTypeInfo(&schema).has_value()); + schema.release(&schema); +} + +TEST(ArrowConverterTest, invalidGenericDecimalScaleFallsBackToPhysicalType) { + ArrowSchema schema{}; + createSchema(&schema, "amount"); + auto metadata = + serializeArrowMetadata({{"logicalType", "DECIMAL"}, {"precision", "6"}, {"scale", "7"}}); + schema.metadata = metadata.data(); + + const auto type = ArrowConverter::fromArrowSchema(&schema); + + ASSERT_EQ(type.getLogicalTypeID(), LogicalTypeID::INT64); + ASSERT_FALSE(tryGetArrowLogicalTypeInfo(&schema).has_value()); + schema.release(&schema); +} + +TEST(ArrowConverterTest, genericDecimalMetadataDoesNotBindFloatBackedStorage) { + ArrowSchema schema{}; + createSchema(&schema, "amount"); + auto metadata = + serializeArrowMetadata({{"logicalType", "DECIMAL"}, {"precision", "9"}, {"scale", "2"}}); + schema.metadata = metadata.data(); + + const auto type = ArrowConverter::fromArrowSchema(&schema); + + ASSERT_EQ(type.getLogicalTypeID(), LogicalTypeID::DOUBLE); + ASSERT_FALSE(tryGetArrowLogicalTypeInfo(&schema).has_value()); + schema.release(&schema); +} + 
+TEST(ArrowConverterTest, malformedSnowflakeRawDataTypeFallsBackToSnowflakeLogicalTypeMetadata) { + ArrowSchema schema{}; + createSchema(&schema, "amount"); + auto metadata = serializeArrowMetadata({{"DATA_TYPE", "NUMBER(bad,2)"}, + {"logicalType", "FIXED"}, {"precision", "7"}, {"scale", "2"}}); + schema.metadata = metadata.data(); + + const auto type = ArrowConverter::fromArrowSchema(&schema); + const auto logicalTypeInfo = tryGetArrowLogicalTypeInfo(&schema); + + ASSERT_EQ(type.getLogicalTypeID(), LogicalTypeID::DECIMAL); + ASSERT_EQ(DecimalType::getPrecision(type), 7u); + ASSERT_EQ(DecimalType::getScale(type), 2u); + ASSERT_TRUE(logicalTypeInfo.has_value()); + ASSERT_EQ(logicalTypeInfo->source, ArrowLogicalTypeInfo::Source::SNOWFLAKE); + schema.release(&schema); +} + +TEST(ArrowConverterTest, malformedSnowflakeRawDataTypeFallsBackToGenericLogicalTypeMetadata) { + ArrowSchema schema{}; + createSchema(&schema, "amount"); + auto metadata = serializeArrowMetadata({{"DATA_TYPE", "NUMBER(bad,2)"}, + {"logicalType", "DECIMAL"}, {"precision", "7"}, {"scale", "2"}}); + schema.metadata = metadata.data(); + + const auto type = ArrowConverter::fromArrowSchema(&schema); + const auto logicalTypeInfo = tryGetArrowLogicalTypeInfo(&schema); + + ASSERT_EQ(type.getLogicalTypeID(), LogicalTypeID::DECIMAL); + ASSERT_EQ(DecimalType::getPrecision(type), 7u); + ASSERT_EQ(DecimalType::getScale(type), 2u); + ASSERT_TRUE(logicalTypeInfo.has_value()); + ASSERT_EQ(logicalTypeInfo->source, ArrowLogicalTypeInfo::Source::GENERIC_METADATA); + schema.release(&schema); +} + +TEST(ArrowConverterTest, + snowflakeRawDataTypeMetadataTakesPrecedenceOverGenericLogicalTypeMetadata) { + ArrowSchema schema{}; + createSchema(&schema, "amount"); + auto metadata = serializeArrowMetadata({{"DATA_TYPE", "NUMBER(12,4)"}, + {"logicalType", "DECIMAL"}, {"precision", "9"}, {"scale", "3"}}); + schema.metadata = metadata.data(); + + const auto type = ArrowConverter::fromArrowSchema(&schema); + const auto 
logicalTypeInfo = tryGetArrowLogicalTypeInfo(&schema); + + ASSERT_EQ(type.getLogicalTypeID(), LogicalTypeID::DECIMAL); + ASSERT_EQ(DecimalType::getPrecision(type), 12u); + ASSERT_EQ(DecimalType::getScale(type), 4u); + ASSERT_TRUE(logicalTypeInfo.has_value()); + ASSERT_EQ(logicalTypeInfo->source, ArrowLogicalTypeInfo::Source::SNOWFLAKE); + schema.release(&schema); +} + +TEST(ArrowConverterTest, scansFloatBackedSnowflakeDecimalMetadataIntoDecimalStorage) { + ArrowSchema schema{}; + createSchema(&schema, "amount"); + auto metadata = + serializeArrowMetadata({{"logicalType", "FIXED"}, {"precision", "9"}, {"scale", "2"}}); + schema.metadata = metadata.data(); + + ArrowArray array{}; + createDoubleArray(&array, {1.25, -2.50, 100.01}); + ValueVector outputVector{LogicalType::DECIMAL(9, 2)}; + + ArrowConverter::fromArrowArray(&schema, &array, outputVector); + + ASSERT_EQ(outputVector.getValue(0), 125); + ASSERT_EQ(outputVector.getValue(1), -250); + ASSERT_EQ(outputVector.getValue(2), 10001); + array.release(&array); + schema.release(&schema); +} + +TEST(ArrowConverterTest, scansFloat32BackedSnowflakeDecimalMetadataIntoDecimalStorage) { + ArrowSchema schema{}; + createSchema(&schema, "amount"); + auto metadata = + serializeArrowMetadata({{"logicalType", "FIXED"}, {"precision", "8"}, {"scale", "1"}}); + schema.metadata = metadata.data(); + + ArrowArray array{}; + createFloatArray(&array, {1.5F, -2.0F, 12.3F}); + ValueVector outputVector{LogicalType::DECIMAL(8, 1)}; + + ArrowConverter::fromArrowArray(&schema, &array, outputVector); + + ASSERT_EQ(outputVector.getValue(0), 15); + ASSERT_EQ(outputVector.getValue(1), -20); + ASSERT_EQ(outputVector.getValue(2), 123); + array.release(&array); + schema.release(&schema); +} + +TEST(ArrowConverterTest, scansIntegerBackedSnowflakeDecimalMetadataIntoDecimalStorage) { + ArrowSchema schema{}; + createSchema(&schema, "amount"); + auto metadata = + serializeArrowMetadata({{"logicalType", "FIXED"}, {"precision", "7"}, {"scale", "2"}}); + 
schema.metadata = metadata.data(); + + ArrowArray array{}; + createInt64Array(&array, {120, 200, 50, 10000}); + ValueVector outputVector{LogicalType::DECIMAL(7, 2)}; + + ArrowConverter::fromArrowArray(&schema, &array, outputVector); + + ASSERT_EQ(outputVector.getValue(0), 120); + ASSERT_EQ(outputVector.getValue(1), 200); + ASSERT_EQ(outputVector.getValue(2), 50); + ASSERT_EQ(outputVector.getValue(3), 10000); + array.release(&array); + schema.release(&schema); +} + TEST_F(ArrowTest, resultToArrow) { auto query = "MATCH (a:person) WHERE a.fName = 'Bob' RETURN a.fName"; auto result = conn->query(query); diff --git a/test/include/arrow_test_utils.h b/test/include/arrow_test_utils.h index 6ee9076c14..dd954e9869 100644 --- a/test/include/arrow_test_utils.h +++ b/test/include/arrow_test_utils.h @@ -50,6 +50,19 @@ inline void createSchema(ArrowSchema* schema, const char* name) { schema->private_data = nullptr; } +template<> +inline void createSchema(ArrowSchema* schema, const char* name) { + schema->format = "f"; // float + schema->name = name; + schema->metadata = nullptr; + schema->flags = ARROW_FLAG_NULLABLE; + schema->n_children = 0; + schema->children = nullptr; + schema->dictionary = nullptr; + schema->release = [](ArrowSchema* s) { s->release = nullptr; }; + schema->private_data = nullptr; +} + template<> inline void createSchema(ArrowSchema* schema, const char* name) { schema->format = "b"; // boolean @@ -315,6 +328,54 @@ inline void createDoubleArray(ArrowArray* array, const std::vector& data array->private_data = private_data; } +// Helper to create a float array from vector +inline void createFloatArray(ArrowArray* array, const std::vector& data) { + struct ArrayPrivateData { + void* validity = nullptr; + void* data = nullptr; + int32_t* offsets = nullptr; + }; + + auto* private_data = new ArrayPrivateData(); + private_data->validity = nullptr; // No nulls + private_data->data = malloc(data.size() * sizeof(float)); + memcpy(private_data->data, data.data(), 
data.size() * sizeof(float)); + + array->length = data.size(); + array->null_count = 0; + array->offset = 0; + array->n_buffers = 2; // validity and data + array->n_children = 0; + array->buffers = static_cast(malloc(sizeof(void*) * 2)); + array->buffers[0] = nullptr; // validity buffer (no nulls) + array->buffers[1] = private_data->data; + array->children = nullptr; + array->dictionary = nullptr; + array->release = [](ArrowArray* a) { + if (a->private_data) { + auto* pd = static_cast(a->private_data); + free(pd->validity); + free(pd->data); + free(pd->offsets); + delete pd; + } + if (a->buffers) { + free(const_cast(a->buffers)); + } + if (a->children) { + for (int64_t i = 0; i < a->n_children; i++) { + if (a->children[i]->release) { + a->children[i]->release(a->children[i]); + } + free(a->children[i]); + } + free(a->children); + } + a->release = nullptr; + }; + array->private_data = private_data; +} + template<> inline void createSchema(ArrowSchema* schema, const char* name) { schema->format = "L"; // uint64 (capital L)