diff --git a/include/neug/common/extra_type_info.h b/include/neug/common/extra_type_info.h index 6f8533926..dfd21785c 100644 --- a/include/neug/common/extra_type_info.h +++ b/include/neug/common/extra_type_info.h @@ -33,6 +33,7 @@ enum class ExtraTypeInfoType : uint8_t { STRING_TYPE_INFO = 3, LIST_TYPE_INFO = 4, STRUCT_TYPE_INFO = 5, + ARRAY_TYPE_INFO = 6, }; struct ExtraTypeInfo { @@ -84,6 +85,15 @@ struct ListTypeInfo : public ExtraTypeInfo { bool EqualsInternal(ExtraTypeInfo* other_p) const override; }; +struct ArrayTypeInfo : public ExtraTypeInfo { + DataType child_type; + uint32_t array_size; + ArrayTypeInfo(DataType child_type_p, uint32_t array_size_p); + + protected: + bool EqualsInternal(ExtraTypeInfo* other_p) const override; +}; + struct StringTypeInfo : public ExtraTypeInfo { size_t max_length; explicit StringTypeInfo(size_t length) diff --git a/include/neug/common/types.h b/include/neug/common/types.h index f97ba7b47..4126e18ca 100644 --- a/include/neug/common/types.h +++ b/include/neug/common/types.h @@ -64,7 +64,7 @@ enum class DataTypeId : uint8_t { kList = 101, // kMap = 102, - // kArray = 108, + kArray = 108, kVertex = 200, kEdge = 201, @@ -113,6 +113,7 @@ struct DataType { static DataType Struct(std::vector children); static DataType List(const DataType& child_type); + static DataType Array(const DataType& child_type, uint32_t array_size); static DataType Varchar(size_t max_length); inline DataTypeId id() const { return id_; } @@ -175,11 +176,16 @@ struct DataType { static constexpr const DataTypeId VERTEX = DataTypeId::kVertex; static constexpr const DataTypeId EDGE = DataTypeId::kEdge; static constexpr const DataTypeId PATH = DataTypeId::kPath; + static constexpr const DataTypeId ARRAY = DataTypeId::kArray; }; struct ListType { static const DataType& GetChildType(const DataType& type); }; +struct ArrayType { + static const DataType& GetChildType(const DataType& type); + static uint32_t GetSize(const DataType& type); +}; struct StructType { 
static const std::vector& GetChildTypes(const DataType& type); static const DataType& GetChildType(const DataType& type, size_t index); diff --git a/include/neug/execution/common/types/value.h b/include/neug/execution/common/types/value.h index 8f78b846c..f6304892e 100644 --- a/include/neug/execution/common/types/value.h +++ b/include/neug/execution/common/types/value.h @@ -42,6 +42,7 @@ class Value { friend struct StringValue; friend struct StructValue; friend struct ListValue; + friend struct ArrayValue; friend struct PathValue; public: @@ -81,6 +82,8 @@ class Value { static Value LIST(const DataType& child_type, std::vector&& values); static Value LIST(std::vector&& values); + static Value ARRAY(const DataType& array_type, std::vector&& values); + static Value STRING(const std::string& str); static Value VARCHAR(const std::string& str, uint16_t max_length); @@ -168,6 +171,11 @@ struct ListValue { static const std::vector& GetChildren(const Value& value); }; +struct ArrayValue { + static const std::vector& GetChildren(const Value& value); + static uint32_t GetSize(const Value& value); +}; + struct StructValue { static const std::vector& GetChildren(const Value& value); }; @@ -616,6 +624,17 @@ bool Value::ApplyComparisonOp(const Value& lhs, const Value& rhs) { } return true; } + case DataTypeId::kArray: { + const auto& lhs_children = ArrayValue::GetChildren(lhs); + const auto& rhs_children = ArrayValue::GetChildren(rhs); + assert(lhs_children.size() == rhs_children.size()); + for (size_t i = 0; i < lhs_children.size(); ++i) { + if (!ApplyComparisonOp(lhs_children[i], rhs_children[i])) { + return false; + } + } + return true; + } case DataTypeId::kVertex: { return OP::operation(lhs.GetValue(), rhs.GetValue()); } diff --git a/include/neug/utils/property/array_column.h b/include/neug/utils/property/array_column.h new file mode 100644 index 000000000..98ae6fb4e --- /dev/null +++ b/include/neug/utils/property/array_column.h @@ -0,0 +1,75 @@ +/** Copyright 2020 Alibaba 
Group Holding Limited. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +#include "neug/common/extra_type_info.h" +#include "neug/common/types.h" +#include "neug/execution/common/types/value.h" +#include "neug/utils/property/column.h" + +namespace neug { + +/** + * @brief Fixed-length array column backed by a child column. + * + * For row i, element j is stored at child_column[i * array_size + j]. + * The ArrayColumn itself has no data buffer; it stores only metadata + * (array_type, array_size, row count) in its module descriptor. + * The child column handles actual storage. 
+ */
+class ArrayColumn : public ColumnBase {
+ public:
+  ArrayColumn() : array_size_(0), size_(0) {}
+  explicit ArrayColumn(const DataType& array_type);
+  ~ArrayColumn() override = default;
+
+  void Open(Checkpoint& ckp, const ModuleDescriptor& desc,
+            MemoryLevel level) override;
+
+  ModuleDescriptor Dump(Checkpoint& ckp) override;
+
+  size_t size() const override { return size_; }
+
+  void resize(size_t size) override;
+  void resize(size_t size, const Property& default_value) override;
+
+  DataTypeId type() const override { return DataTypeId::kArray; }
+
+  void set_any(size_t index, const Property& value,
+               bool insert_safe) override;
+
+  Property get_prop(size_t index) const override;
+
+  void set_value(size_t index, const execution::Value& value);
+
+  execution::Value get_value(size_t index) const;
+
+  void ingest(uint32_t index, OutArchive& arc) override;
+
+  std::string ModuleTypeName() const override { return type_name(); }
+
+  static std::string type_name() { return "column"; }
+
+ private:
+  DataType array_type_;
+  uint32_t array_size_;
+  size_t size_;
+  std::unique_ptr child_column_;
+};
+
+}  // namespace neug
diff --git a/include/neug/utils/property/types.h b/include/neug/utils/property/types.h
index ab17fffcd..af893a0e9 100644
--- a/include/neug/utils/property/types.h
+++ b/include/neug/utils/property/types.h
@@ -585,6 +585,23 @@ struct convert {
           THROW_NOT_SUPPORTED_EXCEPTION("Unrecognized temporal type: " +
                                         temporal.as());
         }
+      } else if (config["array"]) {
+        // Fixed-size array: {array: {component_type: ..., max_length: N}}.
+        auto array_node = config["array"];
+        neug::DataType child_type;
+        if (!array_node["component_type"] ||
+            !decode(array_node["component_type"], child_type)) {
+          LOG(ERROR) << "Failed to parse array component_type";
+          return false;
+        }
+        // as() throws on an absent or non-scalar node; report that case by
+        // returning false, the same way the component_type failure does.
+        if (!array_node["max_length"] || !array_node["max_length"].IsScalar()) {
+          LOG(ERROR) << "Failed to parse array max_length";
+          return false;
+        }
+        uint32_t max_length = array_node["max_length"].as();
+        property_type = neug::DataType::Array(child_type, max_length);
       } else if (config["date"]) {
         property_type = neug::DataTypeId::kDate;
       } else {
@@ -612,6 +629,11 @@ struct convert {
                                :
neug::STRING_DEFAULT_MAX_LENGTH; } else if (type == neug::DataTypeId::kDate) { node["temporal"]["datetime"] = ""; + } else if (type == neug::DataTypeId::kArray) { + auto child_type = neug::ArrayType::GetChildType(type); + uint32_t array_size = neug::ArrayType::GetSize(type); + node["array"]["component_type"] = encode(child_type); + node["array"]["max_length"] = array_size; } else { LOG(ERROR) << "Unrecognized property type: " << type.ToString(); } diff --git a/src/common/extra_type_info.cc b/src/common/extra_type_info.cc index da812bbd0..f9525ca86 100644 --- a/src/common/extra_type_info.cc +++ b/src/common/extra_type_info.cc @@ -84,6 +84,16 @@ bool ListTypeInfo::EqualsInternal(ExtraTypeInfo* other_p) const { return child_type == other.child_type; } +ArrayTypeInfo::ArrayTypeInfo(DataType child_type_p, uint32_t array_size_p) + : ExtraTypeInfo(ExtraTypeInfoType::ARRAY_TYPE_INFO), + child_type(std::move(child_type_p)), + array_size(array_size_p) {} + +bool ArrayTypeInfo::EqualsInternal(ExtraTypeInfo* other_p) const { + auto& other = other_p->Cast(); + return child_type == other.child_type && array_size == other.array_size; +} + bool StringTypeInfo::EqualsInternal(ExtraTypeInfo* other_p) const { auto& other = other_p->Cast(); return max_length == other.max_length; diff --git a/src/common/types.cc b/src/common/types.cc index d4f172475..21d109a40 100644 --- a/src/common/types.cc +++ b/src/common/types.cc @@ -95,6 +95,26 @@ DataType DataType::List(const DataType& child_type) { return DataType(DataTypeId::kList, type_info); } +DataType DataType::Array(const DataType& child_type, uint32_t array_size) { + std::shared_ptr type_info = + std::make_shared(child_type, array_size); + return DataType(DataTypeId::kArray, type_info); +} + +const DataType& ArrayType::GetChildType(const DataType& type) { + assert(type.id() == DataTypeId::kArray); + auto info = type.RawExtraTypeInfo(); + assert(info); + return info->Cast().child_type; +} + +uint32_t ArrayType::GetSize(const DataType& 
type) { + assert(type.id() == DataTypeId::kArray); + auto info = type.RawExtraTypeInfo(); + assert(info); + return info->Cast().array_size; +} + DataType DataType::Varchar(size_t max_length) { std::shared_ptr type_info = std::make_shared(max_length); @@ -147,10 +167,14 @@ DataType parse_from_data_type(const ::common::DataType& ddt) { } } case ::common::DataType::kArray: { - const auto& element_type = ddt.array().component_type(); - auto data_type = parse_from_data_type(element_type); - return DataType(DataTypeId::kList, - std::make_shared(data_type)); + const auto& array_pb = ddt.array(); + const auto& element_type = array_pb.component_type(); + auto child_data_type = parse_from_data_type(element_type); + uint32_t max_length = array_pb.max_length(); + if (max_length > 0) { + return DataType::Array(child_data_type, max_length); + } + return DataType::List(child_data_type); } case ::common::DataType::kTuple: { const auto& component_types = ddt.tuple().component_types(); @@ -255,6 +279,12 @@ std::string DataType::ToString() const { const DataType& child_type = ListType::GetChildType(*this); return "LIST<" + child_type.ToString() + ">"; } + case DataTypeId::kArray: { + const DataType& child_type = ArrayType::GetChildType(*this); + uint32_t array_size = ArrayType::GetSize(*this); + return "ARRAY<" + child_type.ToString() + ", " + + std::to_string(array_size) + ">"; + } case DataTypeId::kStruct: { const auto& child_types = StructType::GetChildTypes(*this); std::string result = "STRUCT<"; diff --git a/src/compiler/binder/bind/read/bind_unwind.cpp b/src/compiler/binder/bind/read/bind_unwind.cpp index cf49281d0..4749d624c 100644 --- a/src/compiler/binder/bind/read/bind_unwind.cpp +++ b/src/compiler/binder/bind/read/bind_unwind.cpp @@ -72,7 +72,7 @@ std::unique_ptr Binder::bindUnwindClause( if (boundExpression->getDataType().getLogicalTypeID() == LogicalTypeID::ARRAY) { auto targetType = LogicalType::LIST( - ArrayType::getChildType(boundExpression->dataType).copy()); + 
common::ArrayType::getChildType(boundExpression->dataType).copy()); boundExpression = expressionBinder.implicitCast(boundExpression, targetType); } diff --git a/src/compiler/function/vector_cast_functions.cpp b/src/compiler/function/vector_cast_functions.cpp index d013f94fd..65be1e68c 100644 --- a/src/compiler/function/vector_cast_functions.cpp +++ b/src/compiler/function/vector_cast_functions.cpp @@ -182,24 +182,24 @@ static bool hasImplicitCastList(const LogicalType& srcType, static bool hasImplicitCastArray(const LogicalType& srcType, const LogicalType& dstType) { - if (ArrayType::getNumElements(srcType) != - ArrayType::getNumElements(dstType)) { + if (common::ArrayType::getNumElements(srcType) != + common::ArrayType::getNumElements(dstType)) { return false; } - return CastFunction::hasImplicitCast(ArrayType::getChildType(srcType), - ArrayType::getChildType(dstType)); + return CastFunction::hasImplicitCast(common::ArrayType::getChildType(srcType), + common::ArrayType::getChildType(dstType)); } static bool hasImplicitCastArrayToList(const LogicalType& srcType, const LogicalType& dstType) { - return CastFunction::hasImplicitCast(ArrayType::getChildType(srcType), + return CastFunction::hasImplicitCast(common::ArrayType::getChildType(srcType), ::ListType::getChildType(dstType)); } static bool hasImplicitCastListToArray(const LogicalType& srcType, const LogicalType& dstType) { return CastFunction::hasImplicitCast(::ListType::getChildType(srcType), - ArrayType::getChildType(dstType)); + common::ArrayType::getChildType(dstType)); } static bool hasImplicitCastStruct(const LogicalType& srcType, diff --git a/src/compiler/gopt/g_type_converter.cpp b/src/compiler/gopt/g_type_converter.cpp index 222d40b69..83d35e7ac 100644 --- a/src/compiler/gopt/g_type_converter.cpp +++ b/src/compiler/gopt/g_type_converter.cpp @@ -404,6 +404,17 @@ GPhysicalTypeConverter::convertSimpleLogicalType( result->set_allocated_temporal(temporalType.release()); break; } + case 
common::LogicalTypeID::ARRAY: { + auto& childType = common::ArrayType::getChildType(type); + auto numElements = common::ArrayType::getNumElements(type); + auto childIrType = convertSimpleLogicalType(childType); + auto arrayPb = std::make_unique<::common::Array>(); + arrayPb->set_allocated_component_type( + childIrType->release_data_type()); + arrayPb->set_max_length(numElements); + result->set_allocated_array(arrayPb.release()); + break; + } default: THROW_EXCEPTION_WITH_FILE_LINE("Unsupported basic type for conversion: " + type.toString()); diff --git a/src/execution/common/columns/columns_utils.cc b/src/execution/common/columns/columns_utils.cc index 1faca69a9..46e332eb8 100644 --- a/src/execution/common/columns/columns_utils.cc +++ b/src/execution/common/columns/columns_utils.cc @@ -45,6 +45,10 @@ std::shared_ptr ColumnsUtils::create_builder( DataType elem_type = ListType::GetChildType(type); return std::make_shared(elem_type); } + case DataTypeId::kArray: { + DataType elem_type = ArrayType::GetChildType(type); + return std::make_shared(elem_type); + } case DataTypeId::kPath: { return std::make_shared(); } diff --git a/src/execution/common/types/value.cc b/src/execution/common/types/value.cc index bf7805a8f..7e4f1e44b 100644 --- a/src/execution/common/types/value.cc +++ b/src/execution/common/types/value.cc @@ -233,6 +233,13 @@ Value Value::LIST(std::vector&& values) { return Value::LIST(type, std::move(values)); } +Value Value::ARRAY(const DataType& array_type, std::vector&& values) { + Value result(array_type); + result.value_info_ = std::make_shared(std::move(values)); + result.is_null_ = false; + return result; +} + Value Value::STRUCT(const DataType& type, std::vector&& struct_values) { Value result(type); result.value_info_ = @@ -301,6 +308,19 @@ const std::vector& ListValue::GetChildren(const Value& value) { return value.value_info_->Get().GetValues(); } +const std::vector& ArrayValue::GetChildren(const Value& value) { + if (value.IsNull()) { + throw 
std::runtime_error("Cannot get children of null ArrayValue"); + } + assert(value.type().id() == DataTypeId::kArray); + return value.value_info_->Get().GetValues(); +} + +uint32_t ArrayValue::GetSize(const Value& value) { + assert(value.type().id() == DataTypeId::kArray); + return ArrayType::GetSize(value.type()); +} + const std::vector& StructValue::GetChildren(const Value& value) { if (value.IsNull()) { throw std::runtime_error("Cannot get children of null StructValue"); @@ -590,6 +610,16 @@ std::string Value::to_string() const { } } return result + "]"; + } else if (type_.id() == DataTypeId::kArray) { + const auto& children = ArrayValue::GetChildren(*this); + std::string result = "["; + for (size_t i = 0; i < children.size(); ++i) { + result += children[i].to_string(); + if (i != children.size() - 1) { + result += ", "; + } + } + return result + "]"; } else if (type_.id() == DataTypeId::kStruct) { const auto& children = StructValue::GetChildren(*this); std::string result = "("; @@ -681,6 +711,18 @@ Value Value::FromJson(const rapidjson::Value& json_value, } return execution::Value::LIST(child_type, std::move(values)); } + case DataTypeId::kArray: { + std::vector values; + if (!json_value.IsArray()) { + return execution::Value::ARRAY(type, std::move(values)); + } + const auto arr = json_value.GetArray(); + auto child_type = ArrayType::GetChildType(type); + for (auto item = arr.begin(); item != arr.end(); ++item) { + values.emplace_back(FromJson(*item, child_type)); + } + return execution::Value::ARRAY(type, std::move(values)); + } default: THROW_NOT_IMPLEMENTED_EXCEPTION( "Deserialization for parameter type " + @@ -723,6 +765,14 @@ rapidjson::Value Value::ToJson(const Value& value, } return list_doc; } + case neug::DataTypeId::kArray: { + rapidjson::Value array_doc(rapidjson::kArrayType); + const auto& elements = execution::ArrayValue::GetChildren(value); + for (size_t i = 0; i < elements.size(); ++i) { + array_doc.PushBack(ToJson(elements[i], allocator), 
allocator); + } + return array_doc; + } case neug::DataTypeId::kDate: { return rapidjson::Value(value.GetValue().to_string().c_str(), allocator); @@ -840,6 +890,11 @@ void encode_value(const Value& val, Encoder& encoder) { const auto& vals = ListValue::GetChildren(val); encoder.put_int(vals.size()); + for (const auto& v : vals) { + encode_value(v, encoder); + } + } else if (type.id() == DataTypeId::kArray) { + const auto& vals = ArrayValue::GetChildren(val); for (const auto& v : vals) { encode_value(v, encoder); } diff --git a/src/storages/graph/schema.cc b/src/storages/graph/schema.cc index 7366406ec..d502d4891 100644 --- a/src/storages/graph/schema.cc +++ b/src/storages/graph/schema.cc @@ -26,6 +26,7 @@ #include #include #include +#include "neug/common/extra_type_info.h" #include "neug/utils/exception/exception.h" #include "neug/utils/id_indexer.h" #include "neug/utils/pb_utils.h" @@ -2298,6 +2299,9 @@ InArchive& operator<<(InArchive& in_archive, const DataType& type) { for (const auto& child_type : child_types) { in_archive << child_type; } + } else if (id == DataTypeId::kArray) { + const auto& array_type_info = type_info->Cast(); + in_archive << array_type_info.child_type << array_type_info.array_size; } else if (id == DataTypeId::kVarchar) { const auto& varchar_type_info = type_info->Cast(); in_archive << varchar_type_info.max_length; @@ -2330,6 +2334,11 @@ OutArchive& operator>>(OutArchive& out_archive, DataType& type) { out_archive >> child_types[i]; } type = DataType::Struct(child_types); + } else if (id == DataTypeId::kArray) { + DataType child_type; + uint32_t array_size; + out_archive >> child_type >> array_size; + type = DataType::Array(child_type, array_size); } else if (id == DataTypeId::kVarchar) { size_t max_length; out_archive >> max_length; diff --git a/src/utils/pb_utils.cc b/src/utils/pb_utils.cc index 1cd5fa0d6..2811563bf 100644 --- a/src/utils/pb_utils.cc +++ b/src/utils/pb_utils.cc @@ -231,8 +231,19 @@ bool data_type_to_property_type(const 
common::DataType& data_type,
       return temporal_type_to_property_type(data_type.temporal(), out_type);
     }
     case common::DataType::kArray: {
-      LOG(ERROR) << "Array type is not supported";
-      return false;
+      const auto& array = data_type.array();
+      DataType child_type;
+      if (!data_type_to_property_type(array.component_type(), child_type)) {
+        LOG(ERROR) << "Failed to parse array component type";
+        return false;
+      }
+      uint32_t max_length = array.max_length();
+      if (max_length > 0) {
+        out_type = DataType::Array(child_type, max_length);
+      } else {
+        out_type = DataType::List(child_type);
+      }
+      return true;
     }
     case common::DataType::kMap: {
       LOG(ERROR) << "Map type is not supported";
@@ -299,6 +310,32 @@ bool common_value_to_value(const DataType& type, const common::Value& value,
     case common::Value::kDate:
       out_value = execution::Value::DATE(Date(value.date().item()));
       break;
+    // Typed array payloads: wrap every item as a scalar Value and tag the
+    // result with the caller-supplied `type`.
+    case common::Value::kI32Array: {
+      const auto& arr = value.i32_array();
+      std::vector elements;
+      elements.reserve(arr.item_size());
+      for (int i = 0; i < arr.item_size(); ++i) {
+        elements.emplace_back(execution::Value::INT32(arr.item(i)));
+      }
+      out_value = execution::Value::ARRAY(type, std::move(elements));
+      break;
+    }
+    case common::Value::kI64Array: {
+      const auto& arr = value.i64_array();
+      std::vector elements;
+      elements.reserve(arr.item_size());
+      for (int i = 0; i < arr.item_size(); ++i) {
+        elements.emplace_back(execution::Value::INT64(arr.item(i)));
+      }
+      out_value = execution::Value::ARRAY(type, std::move(elements));
+      break;
+    }
+    case common::Value::kF64Array: {
+      const auto& arr = value.f64_array();
+      std::vector elements;
+      elements.reserve(arr.item_size());
+      for (int i = 0; i < arr.item_size(); ++i) {
+        elements.emplace_back(execution::Value::DOUBLE(arr.item(i)));
+      }
+      out_value = execution::Value::ARRAY(type, std::move(elements));
+      break;
+    }
     default:
       LOG(ERROR) << "Unknown value type: " << value.DebugString();
       return false;
diff
--git a/src/utils/property/array_column.cc b/src/utils/property/array_column.cc
new file mode 100644
index 000000000..88c6c9576
--- /dev/null
+++ b/src/utils/property/array_column.cc
@@ -0,0 +1,184 @@
+/** Copyright 2020 Alibaba Group Holding Limited.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "neug/utils/property/array_column.h"
+
+#include
+
+#include "neug/storages/module/module_factory.h"
+#include "neug/utils/exception/exception.h"
+#include "neug/utils/serialization/out_archive.h"
+
+namespace neug {
+
+// Builds an empty column for `array_type`; elements live in a child column of
+// the array's element type.
+ArrayColumn::ArrayColumn(const DataType& array_type)
+    : array_type_(array_type),
+      array_size_(ArrayType::GetSize(array_type)),
+      size_(0) {
+  auto child_type = ArrayType::GetChildType(array_type);
+  child_column_ = CreateColumn(child_type);
+}
+
+// Restores the row count and the child column from `desc`.
+void ArrayColumn::Open(Checkpoint& ckp, const ModuleDescriptor& desc,
+                       MemoryLevel level) {
+  // A default-constructed ArrayColumn has no child column; fail loudly
+  // instead of dereferencing a null pointer at the end of this function.
+  if (!child_column_) {
+    THROW_RUNTIME_ERROR(
+        "ArrayColumn::Open: column was constructed without an array type");
+  }
+
+  // Dump() records array_size; a different value here means every row
+  // offset (index * array_size_) would point at the wrong child elements.
+  auto array_size_str = desc.get("array_size");
+  if (array_size_str.has_value() &&
+      std::stoull(array_size_str.value()) != array_size_) {
+    THROW_RUNTIME_ERROR("ArrayColumn::Open: checkpoint array_size " +
+                        array_size_str.value() +
+                        " does not match column array_size " +
+                        std::to_string(array_size_));
+  }
+
+  auto size_str = desc.get("array_row_count");
+  if (size_str.has_value()) {
+    size_ = std::stoull(size_str.value());
+  } else {
+    size_ = 0;
+  }
+
+  auto child_desc_json = desc.get("child_descriptor");
+  ModuleDescriptor child_desc;
+  if (child_desc_json.has_value()) {
+    rapidjson::Document doc;
+    doc.Parse(child_desc_json.value().c_str());
+    if (!doc.HasParseError() && doc.IsObject()) {
+      child_desc = ModuleDescriptor::FromJson(doc);
+    }
+  }
+  child_column_->Open(ckp, child_desc, level);
+}
+
+// Dumps the child column and wraps its descriptor together with this
+// column's row count and array size.
+ModuleDescriptor ArrayColumn::Dump(Checkpoint& ckp) {
+  ModuleDescriptor desc;
+  desc.module_type = ModuleTypeName();
+  desc.set("array_row_count", std::to_string(size_));
+  desc.set("array_size", std::to_string(array_size_));
+
+  auto child_desc = child_column_->Dump(ckp);
+  desc.set("child_descriptor", child_desc.ToJsonString());
+  return desc;
+}
+
+// Resizes to `size` rows, i.e. size * array_size_ child elements.
+void ArrayColumn::resize(size_t size) {
+  size_ = size;
+  child_column_->resize(size_ * array_size_);
+}
+
+// Same as resize(size): `default_value` is not applied to the new rows.
+void ArrayColumn::resize(size_t size, const Property& default_value) {
+  size_ = size;
+  child_column_->resize(size_ * array_size_);
+}
+
+void ArrayColumn::set_any(size_t index, const Property& value,
+                          bool insert_safe) {
+  THROW_NOT_SUPPORTED_EXCEPTION(
+      "ArrayColumn does not support set_any via Property. Use set_value with "
+      "Value instead.");
+}
+
+Property ArrayColumn::get_prop(size_t index) const {
+  THROW_NOT_SUPPORTED_EXCEPTION(
+      "ArrayColumn does not support get_prop. Use get_value instead.");
+}
+
+// Writes the array_size_ elements of `value` into row `index`.
+// Throws if `index` is out of range or the element count is wrong.
+void ArrayColumn::set_value(size_t index,
+                            const execution::Value& value) {
+  if (index >= size_) {
+    THROW_RUNTIME_ERROR("ArrayColumn::set_value: index " +
+                        std::to_string(index) + " out of range (size=" +
+                        std::to_string(size_) + ")");
+  }
+  // get_value() produces ARRAY-typed values, so read those through
+  // ArrayValue; any other value keeps going through ListValue as before.
+  const auto& children =
+      value.type().id() == DataTypeId::kArray
+          ? execution::ArrayValue::GetChildren(value)
+          : execution::ListValue::GetChildren(value);
+  if (children.size() != array_size_) {
+    THROW_INVALID_ARGUMENT_EXCEPTION(
+        "ArrayColumn::set_value: expected " + std::to_string(array_size_) +
+        " elements, got " + std::to_string(children.size()));
+  }
+  size_t base = index * array_size_;
+  for (uint32_t j = 0; j < array_size_; ++j) {
+    auto prop = execution::value_to_property(children[j]);
+    child_column_->set_any(base + j, prop, false);
+  }
+}
+
+// Reads row `index` back as an ARRAY value of array_type_.
+execution::Value ArrayColumn::get_value(size_t index) const {
+  if (index >= size_) {
+    THROW_RUNTIME_ERROR("ArrayColumn::get_value: index " +
+                        std::to_string(index) + " out of range (size=" +
+                        std::to_string(size_) + ")");
+  }
+  std::vector values;
+  values.reserve(array_size_);
+  size_t base = index * array_size_;
+  for (uint32_t j = 0; j < array_size_; ++j) {
+    values.emplace_back(execution::property_to_value(
+        child_column_->get_prop(base + j)));
+  }
+  return execution::Value::ARRAY(array_type_, std::move(values));
+}
+
+// Deserializes the array_size_ elements of row `index` from `arc`.
+void ArrayColumn::ingest(uint32_t index, OutArchive& arc) {
+  if (index >= size_) {
+    THROW_RUNTIME_ERROR("ArrayColumn::ingest: index " +
+                        std::to_string(index) + " out of range (size=" +
+                        std::to_string(size_) + ")");
+  }
+  // index and array_size_ are both uint32_t, so multiply in size_t to avoid
+  // wrapping modulo 2^32.
+  const size_t base = static_cast<size_t>(index) * array_size_;
+  if (array_size_ != 0) {
+    // ingest() takes a uint32_t index (this override has that signature);
+    // reject rows whose last element would not fit before reading `arc`.
+    const size_t last = base + (array_size_ - 1);
+    if (static_cast<size_t>(static_cast<uint32_t>(last)) != last) {
+      THROW_RUNTIME_ERROR("ArrayColumn::ingest: element offset " +
+                          std::to_string(last) +
+                          " exceeds the 32-bit child index range");
+    }
+  }
+  for (uint32_t j = 0; j < array_size_; ++j) {
+    child_column_->ingest(static_cast<uint32_t>(base + j), arc);
+  }
+}
+
+NEUG_REGISTER_MODULE(ArrayColumn);
+
+}  // namespace neug
diff --git a/src/utils/property/column.cc b/src/utils/property/column.cc
index 63e08323c..056bf9f49 100644
--- a/src/utils/property/column.cc
+++ b/src/utils/property/column.cc
@@ -20,6 +20,7 @@
 #include "neug/storages/container/container_utils.h"
 #include "neug/storages/module/module_factory.h"
 #include "neug/utils/id_indexer.h"
+#include "neug/utils/property/array_column.h"
 #include "neug/utils/property/table.h"
 #include "neug/utils/property/types.h"
 #include "neug/utils/serialization/out_archive.h"
@@ -72,6 +73,9 @@ std::unique_ptr CreateColumn(DataType type) {
     }
     return std::make_unique(max_length);
   }
+  case DataTypeId::kArray: {
+    return std::make_unique(type);
+  }
   case DataTypeId::kEmpty: {
     return std::make_unique>();
   }
diff --git a/src/utils/property/property.cc b/src/utils/property/property.cc
index cfd0dcd06..162ff5508 100644
--- a/src/utils/property/property.cc
+++ b/src/utils/property/property.cc
@@ -14,6 +14,7 @@
  */
 
 #include "neug/utils/property/property.h"
+#include "neug/common/extra_type_info.h"
 #include "neug/execution/common/types/value.h"
 #include "neug/utils/serialization/in_archive.h"
 #include "neug/utils/serialization/out_archive.h"
@@ -51,6 +52,16 @@ execution::Value get_default_value(const DataType& type) {
     return execution::Value::TIMESTAMPMS(DateTime(0));
   case DataTypeId::kInterval:
     return execution::Value::INTERVAL(Interval());
+  case
DataTypeId::kArray: { + auto child_type = ArrayType::GetChildType(type); + uint32_t array_size = ArrayType::GetSize(type); + std::vector default_elements; + default_elements.reserve(array_size); + for (uint32_t i = 0; i < array_size; ++i) { + default_elements.emplace_back(get_default_value(child_type)); + } + return execution::Value::ARRAY(type, std::move(default_elements)); + } default: THROW_NOT_SUPPORTED_EXCEPTION( "Unsupported property type for default value: " + type.ToString()); diff --git a/src/utils/yaml_utils.cc b/src/utils/yaml_utils.cc index e6ca90f3f..a82914b3c 100644 --- a/src/utils/yaml_utils.cc +++ b/src/utils/yaml_utils.cc @@ -33,6 +33,7 @@ #include #include +#include "neug/common/extra_type_info.h" #include "neug/utils/exception/exception.h" #include "neug/utils/property/types.h" #include "neug/utils/result.h" @@ -76,6 +77,13 @@ YAML::Node property_type_to_yaml(const DataType& type) { case DataTypeId::kInterval: node["temporal"] = config_parsing::TemporalTypeToYAML(type.id()); break; + case DataTypeId::kArray: { + auto child_type = ArrayType::GetChildType(type); + uint32_t array_size = ArrayType::GetSize(type); + node["array"]["component_type"] = property_type_to_yaml(child_type); + node["array"]["max_length"] = array_size; + break; + } default: THROW_INVALID_ARGUMENT_EXCEPTION( "Unrecognized property type for YAML encoding: " + type.ToString()); diff --git a/tools/python_bind/tests/test_db_array.py b/tools/python_bind/tests/test_db_array.py new file mode 100644 index 000000000..1197d42ba --- /dev/null +++ b/tools/python_bind/tests/test_db_array.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright 2020 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Integration tests for Array property type.""" + +import pytest + +from neug.database import Database + + +def test_array_int32_create_and_query(tmp_path): + """Create a vertex type with INT32[3] array property, insert and query.""" + db_dir = tmp_path / "array_int32" + db_dir.mkdir() + db = Database(db_path=str(db_dir), mode="w") + conn = db.connect() + + conn.execute( + "CREATE NODE TABLE Sensor(" + " id INT64," + " readings INT32[3]," + " PRIMARY KEY(id)" + ");" + ) + + conn.execute( + "CREATE (s:Sensor {id: 1, readings: [10, 20, 30]});" + ) + conn.execute( + "CREATE (s:Sensor {id: 2, readings: [40, 50, 60]});" + ) + + result = conn.execute( + "MATCH (s:Sensor) WHERE s.id = 1 RETURN s.readings;" + ) + record = result.__next__() + assert list(record[0]) == [10, 20, 30] + + result = conn.execute( + "MATCH (s:Sensor) WHERE s.id = 2 RETURN s.readings;" + ) + record = result.__next__() + assert list(record[0]) == [40, 50, 60] + + conn.close() + db.close() + + +def test_array_double_create_and_query(tmp_path): + """Create a vertex type with DOUBLE[2] array property, insert and query.""" + db_dir = tmp_path / "array_double" + db_dir.mkdir() + db = Database(db_path=str(db_dir), mode="w") + conn = db.connect() + + conn.execute( + "CREATE NODE TABLE Vector(" + " id INT64," + " embedding DOUBLE[2]," + " PRIMARY KEY(id)" + ");" + ) + + conn.execute( + "CREATE (v:Vector {id: 1, embedding: [1.5, 2.5]});" + ) + + result = conn.execute( + "MATCH (v:Vector) WHERE v.id = 1 RETURN v.embedding;" + ) + record = result.__next__() + values = list(record[0]) + assert 
abs(values[0] - 1.5) < 1e-6 + assert abs(values[1] - 2.5) < 1e-6 + + conn.close() + db.close() + + +def test_array_update(tmp_path): + """Update an array property value.""" + db_dir = tmp_path / "array_update" + db_dir.mkdir() + db = Database(db_path=str(db_dir), mode="w") + conn = db.connect() + + conn.execute( + "CREATE NODE TABLE Point(" + " id INT64," + " coords INT64[2]," + " PRIMARY KEY(id)" + ");" + ) + + conn.execute( + "CREATE (p:Point {id: 1, coords: [100, 200]});" + ) + + conn.execute( + "MATCH (p:Point) WHERE p.id = 1 SET p.coords = [300, 400];" + ) + + result = conn.execute( + "MATCH (p:Point) WHERE p.id = 1 RETURN p.coords;" + ) + record = result.__next__() + assert list(record[0]) == [300, 400] + + conn.close() + db.close() + + +def test_array_multiple_rows(tmp_path): + """Insert and query multiple rows with array properties.""" + db_dir = tmp_path / "array_multi" + db_dir.mkdir() + db = Database(db_path=str(db_dir), mode="w") + conn = db.connect() + + conn.execute( + "CREATE NODE TABLE Item(" + " id INT64," + " features FLOAT[3]," + " PRIMARY KEY(id)" + ");" + ) + + for i in range(5): + vals = [float(i * 3 + j) for j in range(3)] + conn.execute( + f"CREATE (item:Item {{id: {i}, features: [{vals[0]}, {vals[1]}, {vals[2]}]}});" + ) + + result = conn.execute( + "MATCH (item:Item) RETURN item.id, item.features ORDER BY item.id;" + ) + rows = list(result) + assert len(rows) == 5 + for i, row in enumerate(rows): + expected = [float(i * 3 + j) for j in range(3)] + actual = list(row[1]) + for a, e in zip(actual, expected): + assert abs(a - e) < 1e-5, f"Row {i}: expected {e}, got {a}" + + conn.close() + db.close()